diff --git a/3rdparty/composable_kernel/.clang-format b/3rdparty/composable_kernel/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..22f26749668ad3dfb79cc76151847663d11da7d0 --- /dev/null +++ b/3rdparty/composable_kernel/.clang-format @@ -0,0 +1,90 @@ +--- +Language: Cpp +AccessModifierOffset: 0 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: true +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: false + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + BeforeCatch: true + BeforeElse: true + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IndentCaseLabels: false +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Left +ReflowComments: true +SortIncludes: false +SpaceAfterCStyleCast: false +# SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: Never +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... + diff --git a/3rdparty/composable_kernel/.clang-tidy b/3rdparty/composable_kernel/.clang-tidy new file mode 100644 index 0000000000000000000000000000000000000000..5c2b78168747787b506b0cd6e72af595eb270670 --- /dev/null +++ b/3rdparty/composable_kernel/.clang-tidy @@ -0,0 +1,3 @@ +CheckOptions: + - key: bugprone-reserved-identifier.AllowedIdentifiers + value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__' diff --git a/3rdparty/composable_kernel/.gitignore b/3rdparty/composable_kernel/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..71059ec4d948ed241b76ea7ebf9c79136cfcbf74 --- /dev/null +++ b/3rdparty/composable_kernel/.gitignore @@ -0,0 +1,49 @@ +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch +*.ipch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# vim tags +tags +.tags +.*.swp + +# Editors +.vscode + +# build-in-source directory +build* + +# emacs temporary/backup files +.\#* +\#*\# +*~ + +# GDB temporary files +.gdb_history +install.dir* diff --git a/3rdparty/composable_kernel/CITATION.cff b/3rdparty/composable_kernel/CITATION.cff new file mode 100644 index 0000000000000000000000000000000000000000..d35fe9e5870f0a538c1775efd239dc729e11ab88 --- /dev/null +++ b/3rdparty/composable_kernel/CITATION.cff @@ -0,0 +1,67 @@ +cff-version: 1.2.0 +title: Composable Kernel +message: If you use this software, please cite using the following metadata. +type: software +authors: + - given-names: Chao + family-names: Liu + email: chao.liu2@amd.com + affiliation: AMD + - given-names: Jing + family-names: Zhang + email: jing.zhang3@amd.com + affiliation: AMD + - given-names: Letao + family-names: Qin + email: letao.qin@amd.com + affiliation: AMD + - given-names: Qianfeng + family-names: Zhang + email: qianfeng.zhang@amd.com + affiliation: AMD + - given-names: Liang + family-names: Huang + email: carlus.huang@amd.com + affiliation: AMD + - given-names: Shaojie + family-names: Wang + email: shaojie.wang@amd.com + affiliation: AMD + - given-names: Anthony + family-names: Chang + email: antc@amd.com + affiliation: AMD + - given-names: Chunyu + family-names: Lai + email: chunyu.lai@amd.com + affiliation: AMD + - given-names: Illia + family-names: Silin + email: illia.silin@amd.com + affiliation: AMD + - given-names: Adam + family-names: Osewski + email: adam.osewski@amd.com + affiliation: AMD + - given-names: Poyen + family-names: Chen + email: poyen.chen@amd.com + affiliation: AMD + - given-names: Rosty + family-names: Geyyer + email: rosty.geyyer@amd.com + affiliation: AMD + - given-names: Hanwen + family-names: Chen + - given-names: Tejash + family-names: Shah + - given-names: Xiaoyan + family-names: Zhou + - given-names: Jianfeng + family-names: Yan +repository-code: 'https://github.com/ROCmSoftwarePlatform/composable_kernel' +abstract: Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for Machine Learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel progarmming languages, like HIP C++. +keywords: + - 'CK, Composable Kernel, Tensor Coordinate Transformation' +license: MIT +license-url: https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE diff --git a/3rdparty/composable_kernel/CMakeLists.txt b/3rdparty/composable_kernel/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..53e58890b8bdc0c96a2f1b801a40fc44d735b73a --- /dev/null +++ b/3rdparty/composable_kernel/CMakeLists.txt @@ -0,0 +1,311 @@ +cmake_minimum_required(VERSION 3.14) + +# Check support for CUDA/HIP in Cmake +project(composable_kernel) + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") + +enable_testing() +set(ROCM_PATH $ENV{ROCM_PATH}) +set(ROCM_SYMLINK_LIBS OFF) +find_package(ROCM REQUIRED PATHS ${ROCM_PATH}) + +include(ROCMInstallTargets) +include(ROCMPackageConfigHelpers) +include(ROCMSetupVersion) +include(ROCMInstallSymlinks) +include(ROCMCreatePackage) +include(CheckCXXCompilerFlag) + +rocm_setup_version(VERSION 0.2.0) +include(TargetFlags) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip ${ROCM_PATH} ${ROCM_PATH}/llvm ${ROCM_PATH}/hip) + +option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) + +if(USE_BITINT_EXTENSION_INT4) + add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + add_compile_options(-Wno-bit-int-extension) + message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") +endif() + +## Threads +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) +link_libraries(Threads::Threads) + +## C++ +enable_language(CXX) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") + +## OpenMP +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + # workaround issue hipcc in rocm3.5 cannot find openmp + set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") + set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument") + set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5") + set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) + set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) + set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES}) +else() + find_package(OpenMP REQUIRED) +endif() + +message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") +message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") +message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") +message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") + +link_libraries(${OpenMP_gomp_LIBRARY}) +link_libraries(${OpenMP_pthread_LIBRARY}) + +## HIP +find_package(HIP REQUIRED) +# Override HIP version in config.h, if necessary. +# The variables set by find_package() can't be overwritten, +# therefore let's use intermediate variables. +set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}") +set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}") +set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}") +if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR ) + set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}") + message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}") +endif() +if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR ) + set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}") + message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}") +endif() +if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) + set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}") + message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") +endif() +message(STATUS "Build with HIP ${HIP_VERSION}") +link_libraries(hip::device) +add_compile_definitions(__HIP_PLATFORM_HCC__=1) + +## tidy +include(EnableCompilerWarnings) +set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) +if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") + set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) +# Enable tidy on hip +elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU") + set(CK_TIDY_ERRORS ALL) +endif() + + +include(ClangTidy) +enable_clang_tidy( + CHECKS + * + -abseil-* + -android-cloexec-fopen + # Yea we shouldn't be using rand() + -cert-msc30-c + -bugprone-exception-escape + -bugprone-macro-parentheses + -cert-env33-c + -cert-msc32-c + -cert-msc50-cpp + -cert-msc51-cpp + -cert-dcl37-c + -cert-dcl51-cpp + -clang-analyzer-alpha.core.CastToStruct + -clang-analyzer-optin.performance.Padding + -clang-diagnostic-deprecated-declarations + -clang-diagnostic-extern-c-compat + -clang-diagnostic-unused-command-line-argument + -cppcoreguidelines-avoid-c-arrays + -cppcoreguidelines-avoid-magic-numbers + -cppcoreguidelines-explicit-virtual-functions + -cppcoreguidelines-init-variables + -cppcoreguidelines-macro-usage + -cppcoreguidelines-non-private-member-variables-in-classes + -cppcoreguidelines-pro-bounds-array-to-pointer-decay + -cppcoreguidelines-pro-bounds-constant-array-index + -cppcoreguidelines-pro-bounds-pointer-arithmetic + -cppcoreguidelines-pro-type-member-init + -cppcoreguidelines-pro-type-reinterpret-cast + -cppcoreguidelines-pro-type-union-access + -cppcoreguidelines-pro-type-vararg + -cppcoreguidelines-special-member-functions + -fuchsia-* + -google-explicit-constructor + -google-readability-braces-around-statements + -google-readability-todo + -google-runtime-int + -google-runtime-references + -hicpp-vararg + -hicpp-braces-around-statements + -hicpp-explicit-conversions + -hicpp-named-parameter + -hicpp-no-array-decay + # We really shouldn't use bitwise operators with signed integers, but + # opencl leaves us no choice + -hicpp-avoid-c-arrays + -hicpp-signed-bitwise + -hicpp-special-member-functions + -hicpp-uppercase-literal-suffix + -hicpp-use-auto + -hicpp-use-equals-default + -hicpp-use-override + -llvm-header-guard + -llvm-include-order + #-llvmlibc-* + -llvmlibc-restrict-system-libc-headers + -llvmlibc-callee-namespace + -llvmlibc-implementation-in-namespace + -llvm-else-after-return + -llvm-qualified-auto + -misc-misplaced-const + -misc-non-private-member-variables-in-classes + -misc-no-recursion + -modernize-avoid-bind + -modernize-avoid-c-arrays + -modernize-pass-by-value + -modernize-use-auto + -modernize-use-default-member-init + -modernize-use-equals-default + -modernize-use-trailing-return-type + -modernize-use-transparent-functors + -performance-unnecessary-value-param + -readability-braces-around-statements + -readability-else-after-return + # we are not ready to use it, but very useful + -readability-function-cognitive-complexity + -readability-isolate-declaration + -readability-magic-numbers + -readability-named-parameter + -readability-uppercase-literal-suffix + -readability-convert-member-functions-to-static + -readability-qualified-auto + -readability-redundant-string-init + # too many narrowing conversions in our code + -bugprone-narrowing-conversions + -cppcoreguidelines-narrowing-conversions + -altera-struct-pack-align + -cppcoreguidelines-prefer-member-initializer + ${CK_TIDY_CHECKS} + ${CK_TIDY_ERRORS} + HEADER_FILTER + "\.hpp$" + EXTRA_ARGS + -DCK_USE_CLANG_TIDY +) + +include(CppCheck) +enable_cppcheck( + CHECKS + warning + style + performance + portability + SUPPRESS + ConfigurationNotChecked + constStatement + duplicateCondition + noExplicitConstructor + passedByValue + preprocessorErrorDirective + shadowVariable + unusedFunction + unusedPrivateFunction + unusedStructMember + unmatchedSuppression + FORCE + SOURCES + library/src + INCLUDE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/library/include + DEFINE + CPPCHECK=1 + __linux__=1 +) + +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) + +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include + ${HIP_INCLUDE_DIRS} +) + + +SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") +if(BUILD_DEV) + add_compile_options(-Werror) + add_compile_options(-Weverything) +endif() +message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + +add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + +file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") +file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) +set(CK_DEVICE_INSTANCES) +FOREACH(subdir_path ${dir_list}) + IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}") + list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance) + ENDIF() +ENDFOREACH() +add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) + +rocm_package_setup_component(tests + LIBRARY_NAME composablekernel + PACKAGE_NAME tests # Prevent -static suffix on package name +) + +rocm_package_setup_component(examples + LIBRARY_NAME composablekernel + PACKAGE_NAME examples +) + +rocm_package_setup_component(profiler + LIBRARY_NAME composablekernel + PACKAGE_NAME ckProfiler +) + +add_subdirectory(library) +add_subdirectory(example) +add_subdirectory(test) +add_subdirectory(profiler) + +#Create an interface target for the include only files and call it "composablekernels" +include(CMakePackageConfigHelpers) + +set(version 1.0.0) +write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + VERSION "${version}" + COMPATIBILITY AnyNewerVersion +) + +configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel + NO_CHECK_REQUIRED_COMPONENTS_MACRO +) + +rocm_install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) + +set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") +set(CPACK_RPM_PACKAGE_LICENSE "MIT") + +rocm_create_package( + NAME composablekernel + DESCRIPTION "High Performance Composable Kernel for AMD GPUs" + MAINTAINER "MIOpen Kernels Dev Team " + LDCONFIG + HEADER_ONLY +) diff --git a/3rdparty/composable_kernel/CONTRIBUTORS.md b/3rdparty/composable_kernel/CONTRIBUTORS.md new file mode 100644 index 0000000000000000000000000000000000000000..8ccfe99c3cc73b643f8b92cb654005e54c0774bd --- /dev/null +++ b/3rdparty/composable_kernel/CONTRIBUTORS.md @@ -0,0 +1,31 @@ +# Composable Kernel Developers and Contributors + +This is the list of developers and contributors to Composable Kernel library + + +## Developers +[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2022 + +[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2022 + +[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), 2022 + +Hanwen Chang, 2019-2021, + +Tejash Shah, 2019-2020 + +Xiaoyan Zhou, 2020 + +[Jianfeng Yan](https://github.com/j4yan), 2021-2022 + + +## Product Manager +[Jun Liu](https://github.com/junliume) + + +## Contributors +[Dan Yao](https://github.com/danyao12), [Guangzhao Lu](https://github.com/guangzlu), [Raman Jana](https://github.com/ramjana), [Jehandad Khan](https://github.com/JehandadKhan), [Wen-Heng (Jack) Chung](https://github.com/whchung) + + +## Acknowledgement +CK team works closely with Meta [AITemplate](https://github.com/facebookincubator/AITemplate) team ([Bing Xu](https://github.com/antinucleon), [Hao Lu](https://github.com/hlu1), [Ying Zhang](https://github.com/ipiszy), etc). Most of the lucrative graph optimization opportunities in ML models were identified by AITemplate team, and we also co-designed many high performance fused kernels for AMD GPUs. Without this collaboration, CK would not reach its current potential. diff --git a/3rdparty/composable_kernel/Config.cmake.in b/3rdparty/composable_kernel/Config.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..02978cd4dd49f53188080f1a0404cf87bfd52d7b --- /dev/null +++ b/3rdparty/composable_kernel/Config.cmake.in @@ -0,0 +1,11 @@ +@PACKAGE_INIT@ + +set(_composable_kernel_supported_components device_operations utility) + +foreach(_comp ${composable_kernel_FIND_COMPONENTS}) + if(NOT _comp IN_LIST _composable_kernel_supported_components) + set(composable_kernel_FOUND False) + set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}") + endif() + include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake") +endforeach() diff --git a/3rdparty/composable_kernel/Dockerfile b/3rdparty/composable_kernel/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..d024f966c57fac8ced14abbb00db7cc8f20c7d63 --- /dev/null +++ b/3rdparty/composable_kernel/Dockerfile @@ -0,0 +1,110 @@ +FROM ubuntu:20.04 + +ARG ROCMVERSION=5.3 +ARG compiler_version="release" +ARG compiler_commit="" + +RUN set -xe + +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +# Add rocm repository +RUN apt-get update +RUN apt-get install -y wget gnupg +RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - +RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" +RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - +RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" + +# Install dependencies +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + apt-utils \ + build-essential \ + ccache \ + cmake-data \ + cmake \ + curl \ + git \ + hip-rocclr \ + jq \ + libelf-dev \ + libncurses5-dev \ + libnuma-dev \ + libpthread-stubs0-dev \ + llvm-amdgpu \ + pkg-config \ + python \ + python3 \ + python-dev \ + python3-dev \ + python3-pip \ + software-properties-common \ + rocm-dev \ + rocm-device-libs \ + rocm-cmake \ + vim \ + zlib1g-dev \ + openssh-server \ + clang-format-10 \ + kmod && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Setup ubsan environment to printstacktrace +RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer +ENV UBSAN_OPTIONS=print_stacktrace=1 + +# Install an init system +RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb +RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb + +ARG PREFIX=/opt/rocm +# Install packages for processing the performance results +RUN pip3 install --upgrade pip +RUN pip3 install sqlalchemy +RUN pip3 install pymysql +RUN pip3 install pandas +RUN pip3 install setuptools-rust +RUN pip3 install sshtunnel +# Setup ubsan environment to printstacktrace +ENV UBSAN_OPTIONS=print_stacktrace=1 + +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 +RUN groupadd -f render + +# Install the new rocm-cmake version +RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . --target install + +WORKDIR / + +ENV compiler_version=$compiler_version +ENV compiler_commit=$compiler_commit +RUN sh -c "echo compiler version = '$compiler_version'" +RUN sh -c "echo compiler commit = '$compiler_commit'" + +RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ]; then \ + sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/hip/bin/hipcc.pl && \ + sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/bin/hipcc.pl; \ + fi + +RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ + cd llvm-project && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + +RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ + cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + + +#ENV HIP_CLANG_PATH='/llvm-project/build/bin' +#RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'" diff --git a/3rdparty/composable_kernel/Jenkinsfile b/3rdparty/composable_kernel/Jenkinsfile new file mode 100644 index 0000000000000000000000000000000000000000..7b2e57c1403df3cd4574bf127b96163f1b736d17 --- /dev/null +++ b/3rdparty/composable_kernel/Jenkinsfile @@ -0,0 +1,705 @@ +def rocmnode(name) { + return 'rocmtest && miopen && ' + name +} + +def show_node_info() { + sh """ + echo "NODE_NAME = \$NODE_NAME" + lsb_release -sd + uname -r + ls /opt/ -la + """ +} + +def runShell(String command){ + def responseCode = sh returnStatus: true, script: "${command} > tmp.txt" + def output = readFile(file: "tmp.txt") + echo "tmp.txt contents: $output" + return (output != "") +} + +def getDockerImageName(){ + def img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + return img +} + +def check_host() { + if ("${env.CK_CCACHE}" != "null"){ + def CCACHE_SERVER="${env.CK_CCACHE.split(':')[0]}" + echo "ccache server: ${CCACHE_SERVER}" + sh '''ping -c 1 -p 6379 "${CCACHE_SERVER}" | echo $? > tmp.txt''' + def output = readFile(file: "tmp.txt") + echo "tmp.txt contents: \$output" + return (output != "0") + } + else{ + return 1 + } +} + +def build_compiler(){ + def compiler + if (params.BUILD_COMPILER == "hipcc"){ + compiler = '/opt/rocm/bin/hipcc' + } + else{ + if (params.COMPILER_VERSION == "release"){ + compiler = "/opt/rocm/llvm/bin/clang++" + } + else{ + compiler = "/llvm-project/build/bin/clang++" + } + } + return compiler +} + +def getDockerImage(Map conf=[:]){ + env.DOCKER_BUILDKIT=1 + def prefixpath = conf.get("prefixpath", "/opt/rocm") // prefix:/opt/rocm + def no_cache = conf.get("no_cache", false) + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + echo "ccache server: ${env.CK_CCACHE}" + if(env.CK_CCACHE) + { + if(check_host()) + { + echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" + } + else + { + echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" + } + dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " + env.CCACHE_DIR = """/tmp/ccache_store""" + env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" + } + if(no_cache) + { + dockerArgs = dockerArgs + " --no-cache " + } + echo "Docker Args: ${dockerArgs}" + def image = getDockerImageName() + //Check if image exists + def retimage + try + { + echo "Pulling down image: ${image}" + retimage = docker.image("${image}") + retimage.pull() + } + catch(Exception ex) + { + error "Unable to locate image: ${image}" + } + return [retimage, image] +} + +def buildDocker(install_prefix){ + show_node_info() + env.DOCKER_BUILDKIT=1 + checkout scm + def image_name = getDockerImageName() + echo "Building Docker for ${image_name}" + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + echo "ccache server: ${env.CK_CCACHE}" + if(env.CK_CCACHE) + { + if(check_host()) + { + echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" + } + else + { + echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" + } + dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " + env.CCACHE_DIR = """/tmp/ccache_store""" + env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" + } + + echo "Build Args: ${dockerArgs}" + try{ + if(params.BUILD_DOCKER){ + //force building the new docker if that parameter is true + echo "Building image: ${image_name}" + retimage = docker.build("${image_name}", dockerArgs + ' .') + retimage.push() + } + else{ + echo "Checking for image: ${image_name}" + sh "docker manifest inspect --insecure ${image_name}" + echo "Image: ${image_name} found!! Skipping building image" + } + } + catch(Exception ex){ + echo "Unable to locate image: ${image_name}. Building image now" + retimage = docker.build("${image_name}", dockerArgs + ' .') + retimage.push() + } +} + +def cmake_build(Map conf=[:]){ + + def compiler = build_compiler() + def config_targets = conf.get("config_targets","check") + def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined " + conf.get("extradebugflags", "") + def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","") + def prefixpath = conf.get("prefixpath","/opt/rocm") + def setup_args = conf.get("setup_args","") + + if (prefixpath != "/usr/local"){ + setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " + } + + def build_type_debug = (conf.get("build_type",'release') == 'debug') + + //cmake_env can overwrite default CXX variables. + def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") + + def package_build = (conf.get("package_build","") == "true") + + if (package_build == true) { + config_targets = "package" + } + + if(conf.get("build_install","") == "true") + { + config_targets = 'install ' + config_targets + setup_args = ' -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install' + setup_args + } else{ + setup_args = ' -DBUILD_DEV=On' + setup_args + } + + if(build_type_debug){ + setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args + }else{ + setup_args = " -DCMAKE_BUILD_TYPE=release" + setup_args + } + if(env.CK_CCACHE) + { + setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER='ccache' -DCMAKE_C_COMPILER_LAUNCHER='ccache' " + setup_args + } + echo "ccache server: ${env.CK_CCACHE}" + + def pre_setup_cmd = """ + echo \$HSA_ENABLE_SDMA + ulimit -c unlimited + rm -rf build + mkdir build + rm -rf install + mkdir install + cd build + """ + def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") + // reduce parallelism when compiling, clang uses too much memory + def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j\$(( \$(nproc) / 2 )) ${config_targets}") + def execute_cmd = conf.get("execute_cmd", "") + + def cmd = conf.get("cmd", """ + ${pre_setup_cmd} + ${setup_cmd} + ${build_cmd} + ${execute_cmd} + """) + + echo cmd + sh cmd + + // Only archive from master or develop + if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master")) { + archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true + } +} + +def buildHipClangJob(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = getDockerImageName() + def prefixpath = conf.get("prefixpath", "/opt/rocm") + + // Jenkins is complaining about the render group + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1 " + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + if (params.COMPILER_VERSION != "release"){ + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + } + + def variant = env.STAGE_NAME + + def retimage + (retimage, image) = getDockerImage(conf) + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 5, unit: 'HOURS') + { + cmake_build(conf) + } + } + } + return retimage +} + +def reboot(){ + build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),] +} + +def buildHipClangJobAndReboot(Map conf=[:]){ + try{ + buildHipClangJob(conf) + } + catch(e){ + echo "throwing error exception for the stage" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + +def runCKProfiler(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = getDockerImageName() + def prefixpath = conf.get("prefixpath", "/opt/rocm") + + // Jenkins is complaining about the render group + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1 " + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + if (params.COMPILER_VERSION != "release"){ + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + } + + def variant = env.STAGE_NAME + def retimage + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + try { + (retimage, image) = getDockerImage(conf) + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES'){ + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' + if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + throw new Exception ("GPU not found") + } + else{ + echo "GPU is OK" + } + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + " --no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES'){ + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' + if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + throw new Exception ("GPU not found") + } + else{ + echo "GPU is OK" + } + } + } + } + + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 24, unit: 'HOURS') + { + //cmake_build(conf) + //instead of building, just unstash the ckProfiler and install it + sh """ + rm -rf build + mkdir build + """ + dir("build"){ + unstash 'ckProfiler.tar.gz' + sh 'tar -xvf ckProfiler.tar.gz' + } + + dir("script"){ + if (params.RUN_FULL_QA){ + sh "./run_full_performance_tests.sh 1 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" + archiveArtifacts "perf_gemm.log" + archiveArtifacts "perf_resnet50_N256.log" + archiveArtifacts "perf_resnet50_N4.log" + archiveArtifacts "perf_batched_gemm.log" + archiveArtifacts "perf_grouped_gemm.log" + archiveArtifacts "perf_conv_fwd.log" + archiveArtifacts "perf_conv_bwd_data.log" + archiveArtifacts "perf_gemm_bilinear.log" + archiveArtifacts "perf_reduction.log" + archiveArtifacts "perf_splitK_gemm_verify.log" + archiveArtifacts "perf_splitK_gemm.log" + archiveArtifacts "perf_onnx_gemm.log" + // stash perf files to master + stash name: "perf_gemm.log" + stash name: "perf_resnet50_N256.log" + stash name: "perf_resnet50_N4.log" + stash name: "perf_batched_gemm.log" + stash name: "perf_grouped_gemm.log" + stash name: "perf_conv_fwd.log" + stash name: "perf_conv_bwd_data.log" + stash name: "perf_gemm_bilinear.log" + stash name: "perf_reduction.log" + stash name: "perf_splitK_gemm.log" + stash name: "perf_onnx_gemm.log" + //we will process results on the master node + } + else{ + sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" + archiveArtifacts "perf_gemm.log" + archiveArtifacts "perf_resnet50_N256.log" + archiveArtifacts "perf_resnet50_N4.log" + // stash perf files to master + stash name: "perf_gemm.log" + stash name: "perf_resnet50_N256.log" + stash name: "perf_resnet50_N4.log" + //we will process the results on the master node + } + } + } + } + } + return retimage +} + +def runPerfTest(Map conf=[:]){ + try{ + runCKProfiler(conf) + } + catch(e){ + echo "throwing error exception in performance tests" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + +def Build_CK(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = getDockerImageName() + def prefixpath = conf.get("prefixpath", "/opt/rocm") + + // Jenkins is complaining about the render group + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1 " + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + if (params.COMPILER_VERSION != "release"){ + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + } + + def variant = env.STAGE_NAME + def retimage + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + try { + (retimage, image) = getDockerImage(conf) + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES'){ + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' + if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + throw new Exception ("GPU not found") + } + else{ + echo "GPU is OK" + } + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + " --no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES'){ + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo |tee clinfo.log' + if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + throw new Exception ("GPU not found") + } + else{ + echo "GPU is OK" + } + } + } + } + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 24, unit: 'HOURS') + { + cmake_build(conf) + dir("build"){ + //run tests and examples + sh 'make -j check' + //we only need the ckProfiler to run the performance tests, so we pack and stash it + sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' + stash "ckProfiler.tar.gz" + } + } + } + } + return retimage +} + +def Build_CK_and_Reboot(Map conf=[:]){ + try{ + Build_CK(conf) + } + catch(e){ + echo "throwing error exception while building CK" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + +def process_results(Map conf=[:]){ + env.HSA_ENABLE_SDMA=0 + checkout scm + def image = getDockerImageName() + def prefixpath = "/opt/rocm" + + // Jenkins is complaining about the render group + def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1 " + } + + def variant = env.STAGE_NAME + def retimage + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + try { + (retimage, image) = getDockerImage(conf) + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + } + + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 1, unit: 'HOURS'){ + try{ + dir("script"){ + if (params.RUN_FULL_QA){ + // unstash perf files to master + unstash "perf_gemm.log" + unstash "perf_resnet50_N256.log" + unstash "perf_resnet50_N4.log" + unstash "perf_batched_gemm.log" + unstash "perf_grouped_gemm.log" + unstash "perf_conv_fwd.log" + unstash "perf_conv_bwd_data.log" + unstash "perf_gemm_bilinear.log" + unstash "perf_reduction.log" + unstash "perf_splitK_gemm.log" + unstash "perf_onnx_gemm.log" + sh "./process_qa_data.sh" + } + else{ + // unstash perf files to master + unstash "perf_gemm.log" + unstash "perf_resnet50_N256.log" + unstash "perf_resnet50_N4.log" + sh "./process_perf_data.sh" + } + } + } + catch(e){ + echo "throwing error exception while processing performance test results" + echo 'Exception occurred: ' + e.toString() + throw e + } + } + } +} + +//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;COMPILER_VERSION=release + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open''' : "" + +pipeline { + agent none + triggers { + parameterizedCron(CRON_SETTINGS) + } + options { + parallelsAlwaysFailFast() + } + parameters { + booleanParam( + name: "BUILD_DOCKER", + defaultValue: false, + description: "Force building docker image (default: false), set to true if docker image needs to be updated.") + string( + name: 'ROCMVERSION', + defaultValue: '5.3', + description: 'Specify which ROCM version to use: 5.2.3, or 5.3 (default), etc.') + string( + name: 'COMPILER_VERSION', + defaultValue: 'release', + description: 'Specify which version of compiler to use: ck-9110, release (default), or amd-stg-open.') + string( + name: 'COMPILER_COMMIT', + defaultValue: '', + description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit (default), or use 8a82e4eb7ba28521ba9a9424a0315a8a16590424 commit of amd-stg-open branch.') + string( + name: 'BUILD_COMPILER', + defaultValue: 'hipcc', + description: 'Specify whether to build CK with hipcc (default) or with clang.') + booleanParam( + name: "RUN_FULL_QA", + defaultValue: false, + description: "Select whether to run small set of performance tests (default) or full QA") + } + environment{ + dbuser = "${dbuser}" + dbpassword = "${dbpassword}" + dbsship = "${dbsship}" + dbsshport = "${dbsshport}" + dbsshuser = "${dbsshuser}" + dbsshpassword = "${dbsshpassword}" + status_wrapper_creds = "${status_wrapper_creds}" + gerrit_cred="${gerrit_cred}" + DOCKER_BUILDKIT = "1" + } + stages{ + stage("Build Docker"){ + //when { + // beforeAgent true + // expression { params.BUILD_DOCKER.toBoolean() } + //} + parallel{ + stage('Docker /opt/rocm'){ + agent{ label rocmnode("nogpu") } + steps{ + buildDocker('/opt/rocm') + } + } + } + } + stage("Static checks") { + parallel{ + stage('Clang Format') { + agent{ label rocmnode("nogpu") } + environment{ + execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ + -o -not -path \'*.git*\' -iname \'*.hpp\' \ + -o -not -path \'*.git*\' -iname \'*.cpp\' \ + -o -iname \'*.h.in\' \ + -o -iname \'*.hpp.in\' \ + -o -iname \'*.cpp.in\' \ + -o -iname \'*.cl\' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'" + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + } + } + } + } + + stage("Build CK and run Tests") + { + parallel + { + stage("Build CK and run Tests") + { + agent{ label rocmnode("gfx908 || gfx90a") } + environment{ + setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" """ : """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS="-O3 " """ }" + execute_args = "${params.COMPILER_VERSION == "ck-9110" ? """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ : """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a" -DCMAKE_CXX_FLAGS="-O3" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ }" + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + } + } + } + } + + stage("Performance Tests") + { + parallel + { + stage("Run ckProfiler: gfx908 or gfx90a") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() } + } + options { retry(2) } + agent{ label rocmnode("gfx908 || gfx90a")} + environment{ + setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}" + } + steps{ + runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + } + } + stage("Run ckProfiler: gfx90a") + { + when { + beforeAgent true + expression { params.RUN_FULL_QA.toBoolean() } + } + options { retry(2) } + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}" + } + steps{ + runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + } + } + } + } + stage("Process Performance Test Results") + { + parallel + { + stage("Process results"){ + agent { label 'mici' } + steps{ + process_results() + } + } + } + } + } +} diff --git a/3rdparty/composable_kernel/LICENSE b/3rdparty/composable_kernel/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2fe9a8455efaeda2eab474b2aa038ec2d9e76841 --- /dev/null +++ b/3rdparty/composable_kernel/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) + +SPDX-License-Identifier: MIT +Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/3rdparty/composable_kernel/README.md b/3rdparty/composable_kernel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..129996e268033831a2d29cc96d16edaed2805caa --- /dev/null +++ b/3rdparty/composable_kernel/README.md @@ -0,0 +1,96 @@ +# Composable Kernel + +## Methodology +Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. + +CK utilizes two concepts to achieve performance portability and code maintainability: +* A tile-based programming model +* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". + +![ALT](/doc/image/ck_component.png "CK Components") + +## Code Structure +Current CK library are structured into 4 layers: +* "Templated Tile Operators" layer +* "Templated Kernel and Invoker" layer +* "Instantiated Kernel and Invoker" layer +* "Client API" layer + +![ALT](/doc/image/ck_layer.png "CK Layers") + +## Contributors +The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md) + +## Citation +If you use CK, please use following citations: +* CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???) +* [CITATION.cff](/CITATION.cff) + +## License +CK is released under the MIT license. [License File](/LICENSE) + + +# Build CK + +## Build docker image +```bash +DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . +``` + +## Launch docker +```bash +docker run \ +-it \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +ck:latest \ +/bin/bash +``` + +## Build CK +```bash +mkdir build && cd build + +# Need to specify target ID, example below is for gfx908 and gfx90a +cmake \ +-D CMAKE_PREFIX_PATH=/opt/dtk-23.04 \ +-D CMAKE_CXX_COMPILER=/opt/dtk-23.04/bin/hipcc \ +-D CMAKE_CXX_FLAGS="-O3" \ +-D CMAKE_BUILD_TYPE=Release \ +-D GPU_TARGETS="gfx906;gfx926" \ +.. +``` + +### Build examples and tests +```bash + make -j examples tests + make test +``` + +Instructions for running each individual examples are under [example](/example) + + +## Build ckProfiler +```bash + make -j ckProfiler +``` +Instructions for running ckProfiler are under [profiler](/profiler) + +## Install CK +```bash +make install +``` + +## Using CK as pre-built kernel library +Instructions for using CK as a pre-built kernel library are under [client_example](/client_example) + +## Caveat +### Kernel Timing and Verification +CK's own kernel timer will warn up kernel once, and then run it multiple times +to get average kernel time. For some kernels that use atomic add, this will cause +output buffer to be accumulated multiple times, causing verification failure. +To work around it, do not use CK's own timer and do verification at the same time. +CK's own timer and verification in each example and ckProfiler can be enabled or +disabled from command line. diff --git a/3rdparty/composable_kernel/client_example/01_gemm/CMakeLists.txt b/3rdparty/composable_kernel/client_example/01_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9e741192f90b8216e4b3abe32ae8971fb45ddfee --- /dev/null +++ b/3rdparty/composable_kernel/client_example/01_gemm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_gemm gemm.cpp) +target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/01_gemm/gemm.cpp b/3rdparty/composable_kernel/client_example/01_gemm/gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..668c1bf0041847f3e0a9842a693bb8c1c8e4811e --- /dev/null +++ b/3rdparty/composable_kernel/client_example/01_gemm/gemm.cpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = F32; +using BDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + ADataType *in_data=new ADataType[K*N]; + for(int i=0;i; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << op_name << "support this problem" << std::endl; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + // std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + hipMemcpy(in_data,c_device_buf.GetDeviceBuffer(),sizeof(CDataType) * f_matrix_space_size(M, N, StrideA, CLayout{}),hipMemcpyDeviceToHost); + for(int i=0;i<100;i++){ + std::cout< +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 9) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideD1 = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) * + f_matrix_space_size(M, N, StrideD1, D1Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/3rdparty/composable_kernel/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..512555f978efc6932e38ff31ec20ebf3aab4b063 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/3rdparty/composable_kernel/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72372310321e297b3a0d9101c808eda273675dde --- /dev/null +++ b/3rdparty/composable_kernel/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = FastGelu; + +using ADataType = F16; +using BDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::FastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/03_gemm_layernorm/CMakeLists.txt b/3rdparty/composable_kernel/client_example/03_gemm_layernorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3742e70844b96575e263b22a14b0bb8c4cde7a43 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/03_gemm_layernorm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp) +target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp b/3rdparty/composable_kernel/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c259407d4608f90bd331ee9a9686b56ad62de90 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using BiasDataType = F32; +using CDataType = F16; +using D0DataType = F16; +using ReduceDataType = F32; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op, + const void* p_a, + const void* p_b, + const void* p_bias, + const void* p_d0, + void* p_c, + void* p_mean, + void* p_square_mean, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideD0, + bool time_kernel) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto passOp = PassThrough{}; + auto squareOp = UnarySquareElementOp{}; + auto divOp = UnaryDivElementOp{N}; + + auto argument_ptr = + p_op->MakeArgumentPointer(p_a, + p_b, + p_bias, + {p_d0}, + p_c, + {p_mean, p_square_mean}, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + {&passOp, &passOp, &passOp}, // functor for a, b, c + {&passOp}, // functor for d0 + {&passOp, &squareOp}, // functor for inputs of reduction + {&divOp, &divOp}); // functor for outputs of reduction + + if(p_op->IsSupportedArgument(argument_ptr.get())) + { + auto invoker_ptr = p_op->MakeInvokerPointer(); + + // If we evaluate running time of gemm_reduce. The output may wrong. + // Because we need to initialize the reduction tensor before runing the kernel. + // However we run kernel many times for time_kernel = trie without reinitialize the out + // of reduction tensor. + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl; + + return true; + } + + return false; +} + +template +bool RunDeviceNormalize2D(normalize_op_ptr& p_op, + const void* p_x, + const void* p_mean, + const void* p_square_mean, + const void* p_gamma, + const void* p_beta, + void* p_y, + int M, + int N, + int StrideX, + bool time_kernel) +{ + std::array input = {p_x, p_mean, p_square_mean, p_gamma, p_beta}; + std::array output = {p_y}; + auto normalize_functor = ck::tensor_operation::element_wise::Normalize{}; + + std::array xyLengths = {M, N}; + std::array xyStrides = {StrideX, 1}; + + auto argument_ptr = p_op->MakeArgumentPointer(xyLengths, + {xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}}, + {xyStrides}, + input, + output, + ck::tensor_operation::element_wise::Normalize{}); + + if(p_op->IsSupportedArgument(argument_ptr.get())) + { + auto invoker_ptr = p_op->MakeInvokerPointer(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl; + + return true; + } + + return false; +} + +int main() +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; + ck::index_t StrideD0 = 1024; + + const auto gemm_reduce_ptrs = + ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances< + ADataType, + BDataType, + CDataType, + ALayout, + BLayout, + CLayout>(); + + const auto normalize_ptrs = + ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances< + CDataType, + ReduceDataType, + ReduceDataType, + GammaDataType, + BetaDataType, + LayerNormOutDataType>(); + + std::cout << "found " << gemm_reduce_ptrs.size() + << " gemm_reduceMean_reduceSquareMean instances" << std::endl; + + std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl; + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, CLayout{})); + SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M); + SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); + SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N); + + bool b_time_kernel = true; + bool b_only_run_first_kernel = true; + + // layernorm => (1) + (2) + // (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c) + // (2). normalize(c, mean, square_mean, gamma, beta) + for(auto& gemm_reduce_ptr : gemm_reduce_ptrs) + { + // run first available kernel + if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr, + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + d0_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideD0, + b_time_kernel)) + { + if(b_only_run_first_kernel) + break; + } + else + { + std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + + for(auto& normalize_ptr : normalize_ptrs) + { + if(RunDeviceNormalize2D(normalize_ptr, + c_device_buf.GetDeviceBuffer(), + reduceMean_device_buf.GetDeviceBuffer(), + reduceMeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + layerNorm_device_buf.GetDeviceBuffer(), + M, + N, + StrideC, + b_time_kernel)) + { + if(b_only_run_first_kernel) + break; + } + else + { + std::cout << normalize_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } +} diff --git a/3rdparty/composable_kernel/client_example/04_contraction/CMakeLists.txt b/3rdparty/composable_kernel/client_example/04_contraction/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4bc6780f96d2fe4a4912e3c188b4b5155cc162dd --- /dev/null +++ b/3rdparty/composable_kernel/client_example/04_contraction/CMakeLists.txt @@ -0,0 +1,6 @@ +add_executable(client_contraction_scale contraction_scale.cpp) +target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations) + +add_executable(client_contraction_bilinear contraction_bilinear.cpp) +target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations) + diff --git a/3rdparty/composable_kernel/client_example/04_contraction/contraction_bilinear.cpp b/3rdparty/composable_kernel/client_example/04_contraction/contraction_bilinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..91dead41a4cac19db857b99a233839e9e6647c57 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/04_contraction/contraction_bilinear.cpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" +#include "ck/library/utility/numeric.hpp" + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Bilinear; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 25) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])}; + + alpha = std::stof(argv[23]); + beta = std::stof(argv[24]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg23 to 24: alpha, beta\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem d_device_buf(sizeof(DDataType) * + f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides)); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{alpha, beta}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = ck::accumulate_n( + e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/04_contraction/contraction_scale.cpp b/3rdparty/composable_kernel/client_example/04_contraction/contraction_scale.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e08ee19cdb098b2dfb70a662d59c87008400123 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/04_contraction/contraction_scale.cpp @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp" +#include "ck/library/utility/numeric.hpp" + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Scale; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 20) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + scale = std::stof(argv[19]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg19: scale\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{scale}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = ck::accumulate_n( + e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/05_layernorm/CMakeLists.txt b/3rdparty/composable_kernel/client_example/05_layernorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b582b485d4ce46951aaed98c6256e1c997388bd3 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/05_layernorm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_layernorm2d layernorm2d.cpp) +target_link_libraries(client_layernorm2d PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/05_layernorm/layernorm2d.cpp b/3rdparty/composable_kernel/client_example/05_layernorm/layernorm2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adb41171e12a87ffafe42e4f112a3e89cfc7296e --- /dev/null +++ b/3rdparty/composable_kernel/client_example/05_layernorm/layernorm2d.cpp @@ -0,0 +1,163 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using AccDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = 1024; + + auto xy_size = (M - 1) * Stride + N; + + SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); + SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalization; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {Stride, 1}, // xStrides + {0, 1}, // gammaStrides + {0, 1}, // betaStrides + {Stride, 1}, // yStrides + {1}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {Stride, 1}, // xStrides + {1}, // gammaStrides + {1}, // betaStrides + {Stride, 1}, // yStrides + {1}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/06_softmax/CMakeLists.txt b/3rdparty/composable_kernel/client_example/06_softmax/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b38a0fd9e27570e62df7f701572a24fb4ee842f9 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/06_softmax/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_softmax4d softmax4d.cpp) +target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/06_softmax/softmax4d.cpp b/3rdparty/composable_kernel/client_example/06_softmax/softmax4d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7745ddf34cf7abb65265ada74f0896dedcbbf655 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/06_softmax/softmax4d.cpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/softmax.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 4; +constexpr int NumReduceDim = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::vector in_lengths{2, 8, 128, 1024}; + std::vector in_strides{8 * 128 * 1024, 128 * 1024, 1024, 1}; + std::vector reduce_dims{2, 3}; + + ck::index_t num_elements = + std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies()); + + AccDataType alpha{2.0f}; + AccDataType beta{2.0f}; + + SimpleDeviceMem in(sizeof(InDataType) * num_elements); + SimpleDeviceMem out(sizeof(OutDataType) * num_elements); + + using DeviceOp = ck::tensor_operation::device:: + DeviceSoftmax; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim) + { + continue; + } + + auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, + in_strides, + reduce_dims, + &alpha, + &beta, + in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = num_elements * sizeof(InDataType) + + (beta == 0.0f ? 1 : 2) * num_elements * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, + in_strides, + reduce_dims, + &alpha, + &beta, + in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/3rdparty/composable_kernel/client_example/07_grouped_conv2d_fwd/CMakeLists.txt b/3rdparty/composable_kernel/client_example/07_grouped_conv2d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ddc83168acfbbdf3d58b0909b761947c792a3c06 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/07_grouped_conv2d_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) +target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/07_grouped_conv2d_fwd/grouped_conv2d_fwd.cpp b/3rdparty/composable_kernel/client_example/07_grouped_conv2d_fwd/grouped_conv2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c119c78ab6afc09e0c30f6ade6a52ae9ecb139ea --- /dev/null +++ b/3rdparty/composable_kernel/client_example/07_grouped_conv2d_fwd/grouped_conv2d_fwd.cpp @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +// using InDataType = ck::half_t; +// using WeiDataType = ck::half_t; +// using OutDataType = ck::half_t; +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 256; +static constexpr ck::index_t K = 192; +static constexpr ck::index_t C = 192; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Ho = 26; +static constexpr ck::index_t Wo = 26; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +void printArray(std::array array) +{ + for(int i=0;i in_lengths{G, N, Hi, Wi, C}; + std::array in_strides{0, 0, 0, 0, 1}; + + std::array wei_lengths{G, K, Y, X, C}; + std::array wei_strides{0, 0, 0, 0, 1}; + + std::array out_lengths{G, N, Ho, Wo, K}; + std::array out_strides{0, 0, 0, 0, 1}; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + printArray(in_lengths); + printArray(in_strides); + + // transpose GNHWC/GKYXC/GNHWK to GNCHW/GKCYX/GNCHW + std::rotate( + rbegin(in_lengths), std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 3)); + std::rotate( + rbegin(in_strides), std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 3)); + std::rotate( + rbegin(wei_lengths), std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 3)); + std::rotate( + rbegin(wei_strides), std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 3)); + std::rotate( + rbegin(out_lengths), std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 3)); + std::rotate( + rbegin(out_strides), std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 3)); + + printArray(in_lengths); + printArray(in_strides); + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{0, 0}; + std::array input_right_pads{0, 0}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + + InDataType *in_data=new InDataType[G * N * Hi * Wi * C]; + for(int i=0;i, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cout << op_name << std::endl; + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * G * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + //std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + //<< gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + //std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + hipMemcpy(in_data,out.GetDeviceBuffer(),sizeof(OutDataType) * G * N * Ho * Wo * K,hipMemcpyDeviceToHost); + for(int i=0;i<10;i++){ + std::cout< +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using B0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +constexpr static auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +using ADataType = ck::half_t; +using B0DataType = ck::half_t; +using B1DataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + int G0 = 48; + int G1 = 16; + int M = 1024; + int N = 1024; + int K = 64; + int O = 64; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K); + SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K); + SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O); + + using DeviceOp = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + MaskingSpec>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + {}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // acc0_biases_gs_ms_ns_lengths + {}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + AElementOp{}, + B0ElementOp{}, + Acc0ElementOp{1 / sqrtf(K)}, + B1ElementOp{}, + CElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + G0 * G1; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best instance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + {}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // acc0_biases_gs_ms_ns_lengths + {}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + AElementOp{}, + B0ElementOp{}, + Acc0ElementOp{1 / sqrtf(K)}, + B1ElementOp{}, + CElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/09_quantization/CMakeLists.txt b/3rdparty/composable_kernel/client_example/09_quantization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7dc9b860c0c427a21e6127fcce5556dbb06089e9 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/09_quantization/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations) + +add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations) + +add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) +target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_operations) + +add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) +target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bcb0cefa712cf144edf1916adf0c0f97515d56f0 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using RequantScaleLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array requant_scale_lengths{G, N, K, Ho, Wo}; + std::array requant_scale_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..26c7aa15e2be0814b20b17f9c2c91fe21e70d961 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -0,0 +1,198 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..475b2f03b4f558ad4bc319393472b926ebbfee2d --- /dev/null +++ b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -0,0 +1,198 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using RequantScaleDataType = float; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using RequantScaleLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array requant_scale_lengths{G, N, K, Ho, Wo}; + std::array requant_scale_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {requant_scale_lengths}, + {requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da7b7e6abffd1f5033e4a02700a1a939a950dc0e --- /dev/null +++ b/3rdparty/composable_kernel/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/3rdparty/composable_kernel/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt b/3rdparty/composable_kernel/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e564f3180d8c9295356d90266f2159de47aac060 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) +target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp b/3rdparty/composable_kernel/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55c789804230ccccf66d68be9244c5c4111451e6 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 256; +static constexpr ck::index_t K = 192; +static constexpr ck::index_t C = 192; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, Hi, Wi, C}; + std::array in_strides{0, 0, 0, 0, 1}; + + std::array wei_lengths{G, K, Y, X, C}; + std::array wei_strides{0, 0, 0, 0, 1}; + + std::array out_lengths{G, N, Ho, Wo, K}; + std::array out_strides{0, 0, 0, 0, 1}; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose GNHWC/GKYXC/GNHWK to GNCHW/GKCYX/GNCHW + std::rotate( + rbegin(in_lengths), std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 3)); + std::rotate( + rbegin(in_strides), std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 3)); + std::rotate( + rbegin(wei_lengths), std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 3)); + std::rotate( + rbegin(wei_strides), std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 3)); + std::rotate( + rbegin(out_lengths), std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 3)); + std::rotate( + rbegin(out_strides), std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 3)); + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * G * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/3rdparty/composable_kernel/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt b/3rdparty/composable_kernel/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3e3f6677666545a616a5664080c7fd20ac8ae4e0 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_grouped_conv2d_bwd_weight grouped_conv2d_bwd_weight.cpp) +target_link_libraries(client_grouped_conv2d_bwd_weight PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp b/3rdparty/composable_kernel/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ecc8568959555740e88892ff6233ca1a5cf7333 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 256; +static constexpr ck::index_t K = 192; +static constexpr ck::index_t C = 192; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array input_spatial_lengths{Hi, Wi}; + std::array filter_spatial_lengths{Y, X}; + std::array output_spatial_lengths{Ho, Wo}; + + std::array conv_filter_strides{1, 1}; + std::array conv_filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + ck::index_t split_k = 2; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * G * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/3rdparty/composable_kernel/client_example/12_elementwise_normalization/CMakeLists.txt b/3rdparty/composable_kernel/client_example/12_elementwise_normalization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ba0e1279a402ae00a6e3aead9de6c12c61e14c9 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/12_elementwise_normalization/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp) +target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/3rdparty/composable_kernel/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cf46eda412934f82c6070c28aa8fe8a456778d6 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp" + +using ADataType = ck::half_t; // Input 1 +using BDataType = ck::half_t; // Input 2 +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using AccDataType = float; +using XElementwiseOperation = ck::tensor_operation::element_wise::Add; +using YElementwiseOperation = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + bool time_kernel = true; + + ck::index_t M = 48 * 256; + ck::index_t N = 1024; + ck::index_t Stride = N; + + auto mn_size = (M - 1) * Stride + N; + + SimpleDeviceMem a_dev_buf(sizeof(ADataType) * mn_size); + SimpleDeviceMem b_dev_buf(sizeof(BDataType) * mn_size); + SimpleDeviceMem gamma_dev_buf(sizeof(GammaDataType) * N); + SimpleDeviceMem beta_dev_buf(sizeof(BetaDataType) * N); + SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size); + + std::array ab_input = {a_dev_buf.GetDeviceBuffer(), + b_dev_buf.GetDeviceBuffer()}; + std::vector abStride = {Stride, 1}; + std::array, 2> abStrides = {abStride, abStride}; + + using DeviceOp = ck::tensor_operation::device::DeviceElementwiseNormalization< + ck::Tuple, + GammaDataType, + BetaDataType, + AccDataType, + YDataType, + XElementwiseOperation, + YElementwiseOperation, + Rank, + NumReduceDim>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + abStrides, + {0, 1}, // gammaStrides + {0, 1}, // betaStrides + {Stride, 1}, // yStrides + {1}, // reduceDims + 1e-4, + ab_input, + gamma_dev_buf.GetDeviceBuffer(), + beta_dev_buf.GetDeviceBuffer(), + y_dev_buf.GetDeviceBuffer(), + XElementwiseOperation{}, + YElementwiseOperation{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(ADataType) * M * N + sizeof(BDataType) * M * N + + sizeof(GammaDataType) * N + sizeof(BetaDataType) * N + + sizeof(YDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + abStrides, + {1}, // gammaStrides + {1}, // betaStrides + {Stride, 1}, // yStrides + {1}, // reduceDims + 1e-4, + ab_input, + gamma_dev_buf.GetDeviceBuffer(), + beta_dev_buf.GetDeviceBuffer(), + y_dev_buf.GetDeviceBuffer(), + XElementwiseOperation{}, + YElementwiseOperation{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/13_batchnorm/CMakeLists.txt b/3rdparty/composable_kernel/client_example/13_batchnorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..54669678ae6f107d87eb6b979a3da7761d9b25d5 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/13_batchnorm/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp) +add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp) +target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations) +target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/3rdparty/composable_kernel/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ef21986a4d9a5f1eb25e21a1073c4cc341da88d --- /dev/null +++ b/3rdparty/composable_kernel/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp" + +using XDataType = ck::half_t; +using DxDataType = float; +using DyDataType = float; +using AccDataType = float; +using ScaleDataType = ck::half_t; +using DscaleDbiasDataType = float; +using MeanVarDataType = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 4; +constexpr int NumBatchNormReduceDim = 3; + +const double epsilon = std::numeric_limits::epsilon(); + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array xyLengths{16, 8, 128, 256}; + std::array xyStrides{8 * 128 * 256, 128 * 256, 256, 1}; + std::array scaleBiasMeanVarLengths{256}; + std::array scaleBiasMeanVarStrides{1}; + std::array reduceDims{0, 1, 2}; + + ck::index_t numXYElement = + std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies()); + + ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + 1, + std::multiplies()); + + SimpleDeviceMem x(sizeof(XDataType) * numXYElement); + SimpleDeviceMem dy(sizeof(DyDataType) * numXYElement); + SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem dx(sizeof(DxDataType) * numXYElement); + SimpleDeviceMem dscale(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem dbias(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement); + + using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + dy.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + epsilon, + PassThrough{}, + dx.GetDeviceBuffer(), + dscale.GetDeviceBuffer(), + dbias.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + numXYElement * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) + + numScaleBiasMeanVarElement * + (sizeof(ScaleDataType) + sizeof(DscaleDbiasDataType) * 2 + + sizeof(MeanVarDataType) * 2); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + dy.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + epsilon, + PassThrough{}, + dx.GetDeviceBuffer(), + dscale.GetDeviceBuffer(), + dbias.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp b/3rdparty/composable_kernel/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..322667a46bacae8d0c681939c3890ef9ff476b0e --- /dev/null +++ b/3rdparty/composable_kernel/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp" + +using XDataType = float; +using YDataType = float; +using AccDataType = float; +using ScaleDataType = AccDataType; +using BiasDataType = AccDataType; +using MeanVarDataType = AccDataType; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 4; +constexpr int NumBatchNormReduceDim = 3; + +const double epsilon = std::numeric_limits::epsilon(); +const double averageFactor = 0.1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array xyLengths{16, 8, 128, 256}; + std::array xyStrides{8 * 128 * 256, 128 * 256, 256, 1}; + std::array scaleBiasMeanVarLengths{256}; + std::array scaleBiasMeanVarStrides{1}; + std::array reduceDims{0, 1, 2}; + + ck::index_t numXYElement = + std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies()); + + ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + 1, + std::multiplies()); + + SimpleDeviceMem x(sizeof(XDataType) * numXYElement); + SimpleDeviceMem y(sizeof(YDataType) * numXYElement); + SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem bias(sizeof(BiasDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + + using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + bias.GetDeviceBuffer(), + epsilon, + PassThrough{}, + y.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + averageFactor, + nullptr, + nullptr); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + numXYElement * (sizeof(XDataType) + sizeof(YDataType)) + + numScaleBiasMeanVarElement * (sizeof(ScaleDataType) + sizeof(BiasDataType) + + sizeof(MeanVarDataType) + sizeof(MeanVarDataType)); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + bias.GetDeviceBuffer(), + epsilon, + PassThrough{}, + y.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + averageFactor, + nullptr, + nullptr); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/14_instance_id/CMakeLists.txt b/3rdparty/composable_kernel/client_example/14_instance_id/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..87b2a9a0cbc82a82c9567266178ee431fb9149e7 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/14_instance_id/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp) +target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_operations) diff --git a/3rdparty/composable_kernel/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp b/3rdparty/composable_kernel/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9cfeee1cfe106e69f83dc0184f3956d6751a2947 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp" + +using XDataType = float; +using YDataType = float; +using AccDataType = float; +using ScaleDataType = AccDataType; +using BiasDataType = AccDataType; +using MeanVarDataType = AccDataType; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 4; +constexpr int NumBatchNormReduceDim = 3; + +const double epsilon = std::numeric_limits::epsilon(); +const double averageFactor = 0.1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// In the actual application, the instance index and name are usually from the perf db +static int instance_index = -1; +static std::string instance_name; + +int main(int argc, char* argv[]) +{ + std::array xyLengths{16, 8, 128, 256}; + std::array xyStrides{8 * 128 * 256, 128 * 256, 256, 1}; + std::array scaleBiasMeanVarLengths{256}; + std::array scaleBiasMeanVarStrides{1}; + std::array reduceDims{0, 1, 2}; + + ck::index_t numXYElement = + std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies()); + + ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + 1, + std::multiplies()); + + SimpleDeviceMem x(sizeof(XDataType) * numXYElement); + SimpleDeviceMem y(sizeof(YDataType) * numXYElement); + SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem bias(sizeof(BiasDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + + using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + bool found = false; + int best_op_index = -1; + float best_ave_time = std::numeric_limits::max(); + + // profile device operation instances and save the best performant instance index and instance + // name + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + bias.GetDeviceBuffer(), + epsilon, + PassThrough{}, + y.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + averageFactor, + nullptr, + nullptr); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + if(ave_time < best_ave_time) + { + found = true; + best_op_index = i; + best_ave_time = ave_time; + } + } + } + + if(found) + { + instance_index = best_op_index; + instance_name = op_ptrs[instance_index]->GetTypeIdHashCode(); + }; + + // simulate the execution of the operation when the instance index and name are available + const auto op_ptrs_2 = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(instance_index >= 0 && instance_index < op_ptrs_2.size()) + { + auto& op_ptr = op_ptrs_2[instance_index]; + + if(op_ptr->GetTypeIdHashCode() == instance_name) + { + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + bias.GetDeviceBuffer(), + epsilon, + PassThrough{}, + y.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + averageFactor, + nullptr, + nullptr); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float exec_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + size_t num_bytes = numXYElement * (sizeof(XDataType) + sizeof(YDataType)) + + numScaleBiasMeanVarElement * + (sizeof(ScaleDataType) + sizeof(BiasDataType) + + sizeof(MeanVarDataType) + sizeof(MeanVarDataType)); + + float gb_per_sec = num_bytes / 1.E6 / exec_time; + + std::cout << "Kernel execution time: " << std::setw(10) << exec_time + << " ms, effective data transfer bandwidth: " << gb_per_sec << " GB/s" + << std::endl; + } + }; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/client_example/CMakeLists.txt b/3rdparty/composable_kernel/client_example/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a5993d080b8ad1afd098c4b0bb0399132a0fb779 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.15) +project(ck_app) +add_compile_options(-std=c++17) + +find_package(composable_kernel 1.0.0 COMPONENTS device_operations) +find_package(hip REQUIRED PATHS $ENV{ROCM_PATH}) +message(STATUS "Build with HIP ${hip_VERSION}") + +# add all example subdir +file(GLOB dir_list LIST_DIRECTORIES true *) +FOREACH(subdir ${dir_list}) + IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build")) + add_subdirectory(${subdir}) + ENDIF() +ENDFOREACH() diff --git a/3rdparty/composable_kernel/client_example/README.md b/3rdparty/composable_kernel/client_example/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5c84f1d145d80614f417d7cedd3988fd1cb4eaf4 --- /dev/null +++ b/3rdparty/composable_kernel/client_example/README.md @@ -0,0 +1,18 @@ +## +Client application links to CK library, and therefore CK library needs to be installed before building client applications. + + +## Build +```bash +mkdir -p client_example/build +cd client_example/build +``` + +```bash +cmake -D CMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc -D CMAKE_PREFIX_PATH="${ROCM_PATH};${PATH_TO_CK_INSTALL_DIRECTORY}" .. +``` + +### Build client example +```bash + make -j +``` diff --git a/3rdparty/composable_kernel/cmake/Analyzers.cmake b/3rdparty/composable_kernel/cmake/Analyzers.cmake new file mode 100644 index 0000000000000000000000000000000000000000..1bf1a52c680e6a704dfc4a5da62f728f579e8c18 --- /dev/null +++ b/3rdparty/composable_kernel/cmake/Analyzers.cmake @@ -0,0 +1,34 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ + +if(NOT TARGET analyze) + add_custom_target(analyze) +endif() + +function(mark_as_analyzer) + add_dependencies(analyze ${ARGN}) +endfunction() + diff --git a/3rdparty/composable_kernel/cmake/ClangTidy.cmake b/3rdparty/composable_kernel/cmake/ClangTidy.cmake new file mode 100644 index 0000000000000000000000000000000000000000..654d4b636446416c2d6daa3b5a6ebb69712440fe --- /dev/null +++ b/3rdparty/composable_kernel/cmake/ClangTidy.cmake @@ -0,0 +1,162 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +include(CMakeParseArguments) +include(Analyzers) + +get_filename_component(CLANG_TIDY_EXE_HINT "${CMAKE_CXX_COMPILER}" PATH) +set(ROCM_PATH $ENV{ROCM_PATH}) +find_program(CLANG_TIDY_EXE + NAMES + clang-tidy + clang-tidy-5.0 + clang-tidy-4.0 + clang-tidy-3.9 + clang-tidy-3.8 + clang-tidy-3.7 + clang-tidy-3.6 + clang-tidy-3.5 + HINTS + ${CLANG_TIDY_EXE_HINT} + PATH_SUFFIXES + compiler/bin + PATHS + ${ROCM_PATH}/llvm/bin + ${ROCM_PATH}/hcc + /usr/local/opt/llvm/bin +) + +function(find_clang_tidy_version VAR) + execute_process(COMMAND ${CLANG_TIDY_EXE} -version OUTPUT_VARIABLE VERSION_OUTPUT) + separate_arguments(VERSION_OUTPUT_LIST UNIX_COMMAND "${VERSION_OUTPUT}") + list(FIND VERSION_OUTPUT_LIST "version" VERSION_INDEX) + if(VERSION_INDEX GREATER 0) + math(EXPR VERSION_INDEX "${VERSION_INDEX} + 1") + list(GET VERSION_OUTPUT_LIST ${VERSION_INDEX} VERSION) + set(${VAR} ${VERSION} PARENT_SCOPE) + else() + set(${VAR} "0.0" PARENT_SCOPE) + endif() + +endfunction() + +if( NOT CLANG_TIDY_EXE ) + message( STATUS "Clang tidy not found" ) + set(CLANG_TIDY_VERSION "0.0") +else() + find_clang_tidy_version(CLANG_TIDY_VERSION) + message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}") +endif() + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits) +file(MAKE_DIRECTORY ${CLANG_TIDY_FIXIT_DIR}) +set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CLANG_TIDY_FIXIT_DIR}) + +macro(enable_clang_tidy) + set(options ANALYZE_TEMPORARY_DTORS ALL) + set(oneValueArgs HEADER_FILTER) + set(multiValueArgs CHECKS ERRORS EXTRA_ARGS) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}") + string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}") + set(CLANG_TIDY_EXTRA_ARGS) + foreach(ARG ${PARSE_EXTRA_ARGS}) + list(APPEND CLANG_TIDY_EXTRA_ARGS "-extra-arg=${ARG}") + endforeach() + + set(CLANG_TIDY_ALL) + if(PARSE_ALL) + set(CLANG_TIDY_ALL ALL) + endif() + + message(STATUS "Clang tidy checks: ${CLANG_TIDY_CHECKS}") + + if (${PARSE_ANALYZE_TEMPORARY_DTORS}) + set(CLANG_TIDY_ANALYZE_TEMPORARY_DTORS "-analyze-temporary-dtors") + endif() + + if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0") + set(CLANG_TIDY_ERRORS_ARG "") + else() + set(CLANG_TIDY_ERRORS_ARG "-warnings-as-errors='${CLANG_TIDY_ERRORS}'") + endif() + + if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0") + set(CLANG_TIDY_QUIET_ARG "") + else() + set(CLANG_TIDY_QUIET_ARG "-quiet") + endif() + + if(PARSE_HEADER_FILTER) + string(REPLACE "$" "$$" CLANG_TIDY_HEADER_FILTER "${PARSE_HEADER_FILTER}") + else() + set(CLANG_TIDY_HEADER_FILTER ".*") + endif() + + set(CLANG_TIDY_COMMAND + ${CLANG_TIDY_EXE} + ${CLANG_TIDY_QUIET_ARG} + -p ${CMAKE_BINARY_DIR} + -checks='${CLANG_TIDY_CHECKS}' + ${CLANG_TIDY_ERRORS_ARG} + ${CLANG_TIDY_EXTRA_ARGS} + ${CLANG_TIDY_ANALYZE_TEMPORARY_DTORS} + -header-filter='${CLANG_TIDY_HEADER_FILTER}' + ) + add_custom_target(tidy ${CLANG_TIDY_ALL}) + mark_as_analyzer(tidy) + add_custom_target(tidy-base) + add_custom_target(tidy-make-fixit-dir COMMAND ${CMAKE_COMMAND} -E make_directory ${CLANG_TIDY_FIXIT_DIR}) + add_custom_target(tidy-rm-fixit-dir COMMAND ${CMAKE_COMMAND} -E remove_directory ${CLANG_TIDY_FIXIT_DIR}) + add_dependencies(tidy-make-fixit-dir tidy-rm-fixit-dir) + add_dependencies(tidy-base tidy-make-fixit-dir) +endmacro() + +function(clang_tidy_check TARGET) + get_target_property(SOURCES ${TARGET} SOURCES) + # TODO: Use generator expressions instead + # COMMAND ${CLANG_TIDY_COMMAND} $ + # COMMAND ${CLANG_TIDY_COMMAND} $, > + foreach(SOURCE ${SOURCES}) + if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS")) + string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file) + set(tidy_target tidy-target-${TARGET}-${tidy_file}) + add_custom_target(${tidy_target} + # for some targets clang-tidy not able to get information from .clang-tidy + DEPENDS ${SOURCE} + COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..." + ) + add_dependencies(${tidy_target} ${TARGET}) + add_dependencies(${tidy_target} tidy-base) + add_dependencies(tidy ${tidy_target}) + endif() + endforeach() +endfunction() + diff --git a/3rdparty/composable_kernel/cmake/CppCheck.cmake b/3rdparty/composable_kernel/cmake/CppCheck.cmake new file mode 100644 index 0000000000000000000000000000000000000000..63e1441f4778bff3838d2f4fe88905b15441c5f2 --- /dev/null +++ b/3rdparty/composable_kernel/cmake/CppCheck.cmake @@ -0,0 +1,130 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ + +include(CMakeParseArguments) +include(ProcessorCount) +include(Analyzers) + +find_program(CPPCHECK_EXE + NAMES + cppcheck + PATHS + $ENV{ROCM_PATH}/bin +) + +ProcessorCount(CPPCHECK_JOBS) + +set(CPPCHECK_BUILD_DIR ${CMAKE_BINARY_DIR}/cppcheck-build) +file(MAKE_DIRECTORY ${CPPCHECK_BUILD_DIR}) +set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CPPCHECK_BUILD_DIR}) + +macro(enable_cppcheck) + set(options FORCE) + set(oneValueArgs) + set(multiValueArgs CHECKS SUPPRESS DEFINE UNDEFINE INCLUDE SOURCES) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + string(REPLACE ";" "," CPPCHECK_CHECKS "${PARSE_CHECKS}") + string(REPLACE ";" "\n" CPPCHECK_SUPPRESS "${PARSE_SUPPRESS};*:/usr/*") + file(WRITE ${CMAKE_BINARY_DIR}/cppcheck-supressions "${CPPCHECK_SUPPRESS}") + set(CPPCHECK_DEFINES) + foreach(DEF ${PARSE_DEFINE}) + set(CPPCHECK_DEFINES "${CPPCHECK_DEFINES} -D${DEF}") + endforeach() + + set(CPPCHECK_UNDEFINES) + foreach(DEF ${PARSE_UNDEFINE}) + set(CPPCHECK_UNDEFINES "${CPPCHECK_UNDEFINES} -U${DEF}") + endforeach() + + set(CPPCHECK_INCLUDES) + foreach(INC ${PARSE_INCLUDE}) + set(CPPCHECK_INCLUDES "${CPPCHECK_INCLUDES} -I${INC}") + endforeach() + + # set(CPPCHECK_FORCE) + set(CPPCHECK_FORCE "--project=${CMAKE_BINARY_DIR}/compile_commands.json") + if(PARSE_FORCE) + set(CPPCHECK_FORCE --force) + endif() + + set(SOURCES) + set(GLOBS) + foreach(SOURCE ${PARSE_SOURCES}) + get_filename_component(ABS_SOURCE ${SOURCE} ABSOLUTE) + if(EXISTS ${ABS_SOURCE}) + if(IS_DIRECTORY ${ABS_SOURCE}) + set(GLOBS "${GLOBS} ${ABS_SOURCE}/*.cpp ${ABS_SOURCE}/*.hpp ${ABS_SOURCE}/*.cxx ${ABS_SOURCE}/*.c ${ABS_SOURCE}/*.h") + else() + set(SOURCES "${SOURCES} ${ABS_SOURCE}") + endif() + else() + set(GLOBS "${GLOBS} ${ABS_SOURCE}") + endif() + endforeach() + + file(WRITE ${CMAKE_BINARY_DIR}/cppcheck.cmake " + file(GLOB_RECURSE GSRCS ${GLOBS}) + set(CPPCHECK_COMMAND + ${CPPCHECK_EXE} + -q + # -v + # --report-progress + ${CPPCHECK_FORCE} + --cppcheck-build-dir=${CPPCHECK_BUILD_DIR} + --platform=native + --template=gcc + --error-exitcode=1 + -j ${CPPCHECK_JOBS} + ${CPPCHECK_DEFINES} + ${CPPCHECK_UNDEFINES} + ${CPPCHECK_INCLUDES} + --enable=${CPPCHECK_CHECKS} + --inline-suppr + --suppressions-list=${CMAKE_BINARY_DIR}/cppcheck-supressions + ${SOURCES} \${GSRCS} + ) + string(REPLACE \";\" \" \" CPPCHECK_SHOW_COMMAND \"\${CPPCHECK_COMMAND}\") + message(\"\${CPPCHECK_SHOW_COMMAND}\") + execute_process( + COMMAND \${CPPCHECK_COMMAND} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE RESULT + ) + if(NOT RESULT EQUAL 0) + message(FATAL_ERROR \"Cppcheck failed\") + endif() +") + + add_custom_target(cppcheck + COMMAND ${CMAKE_COMMAND} -P ${CMAKE_BINARY_DIR}/cppcheck.cmake + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "cppcheck: Running cppcheck..." + ) + mark_as_analyzer(cppcheck) +endmacro() + + diff --git a/3rdparty/composable_kernel/cmake/DoxygenDoc.cmake b/3rdparty/composable_kernel/cmake/DoxygenDoc.cmake new file mode 100644 index 0000000000000000000000000000000000000000..2e3669fcdf48fde87c23a84c2493ce26aecda7af --- /dev/null +++ b/3rdparty/composable_kernel/cmake/DoxygenDoc.cmake @@ -0,0 +1,355 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +include(CMakeParseArguments) +include(MainDoc) + +find_program(DOXYGEN_EXECUTABLE NAMES doxygen + PATH_SUFFIXES bin + DOC "Doxygen documentation generator" +) +mark_as_advanced(DOXYGEN_EXECUTABLE) + +find_path(DOT_EXECUTABLE NAMES dot + PATH_SUFFIXES bin + DOC "Graphviz" +) +mark_as_advanced(DOT_EXECUTABLE) + +set(DOXYGEN_ARGS +ABBREVIATE_BRIEF +ALIASES +ALLEXTERNALS +ALLOW_UNICODE_NAMES +ALPHABETICAL_INDEX +ALWAYS_DETAILED_SEC +AUTOLINK_SUPPORT +BINARY_TOC +BRIEF_MEMBER_DESC +BUILTIN_STL_SUPPORT +CALLER_GRAPH +CALL_GRAPH +CASE_SENSE_NAMES +CHM_FILE +CHM_INDEX_ENCODING +CITE_BIB_FILES +CLANG_ASSISTED_PARSING +CLANG_OPTIONS +CLASS_DIAGRAMS +CLASS_GRAPH +COLLABORATION_GRAPH +COLS_IN_ALPHA_INDEX +COMPACT_LATEX +COMPACT_RTF +CPP_CLI_SUPPORT +CREATE_SUBDIRS +DIAFILE_DIRS +DIA_PATH +DIRECTORY_GRAPH +DISABLE_INDEX +DISTRIBUTE_GROUP_DOC +DOCBOOK_OUTPUT +DOCBOOK_PROGRAMLISTING +DOCSET_BUNDLE_ID +DOCSET_FEEDNAME +DOCSET_PUBLISHER_ID +DOCSET_PUBLISHER_NAME +DOTFILE_DIRS +DOT_CLEANUP +DOT_FONTNAME +DOT_FONTPATH +DOT_FONTSIZE +DOT_GRAPH_MAX_NODES +DOT_IMAGE_FORMAT +DOT_MULTI_TARGETS +DOT_NUM_THREADS +# DOT_PATH +DOT_TRANSPARENT +DOXYFILE_ENCODING +ECLIPSE_DOC_ID +ENABLED_SECTIONS +ENABLE_PREPROCESSING +ENUM_VALUES_PER_LINE +EXAMPLE_PATH +EXAMPLE_PATTERNS +EXAMPLE_RECURSIVE +EXCLUDE +EXCLUDE_PATTERNS +EXCLUDE_SYMBOLS +EXCLUDE_SYMLINKS +EXPAND_AS_DEFINED +EXPAND_ONLY_PREDEF +EXTENSION_MAPPING +EXTERNAL_GROUPS +EXTERNAL_PAGES +EXTERNAL_SEARCH +EXTERNAL_SEARCH_ID +EXTRACT_ALL +EXTRACT_ANON_NSPACES +EXTRACT_LOCAL_CLASSES +EXTRACT_LOCAL_METHODS +EXTRACT_PACKAGE +EXTRACT_PRIVATE +EXTRACT_STATIC +EXTRA_PACKAGES +EXTRA_SEARCH_MAPPINGS +EXT_LINKS_IN_WINDOW +FILE_PATTERNS +FILE_VERSION_FILTER +FILTER_PATTERNS +FILTER_SOURCE_FILES +FILTER_SOURCE_PATTERNS +FORCE_LOCAL_INCLUDES +FORMULA_FONTSIZE +FORMULA_TRANSPARENT +FULL_PATH_NAMES +GENERATE_AUTOGEN_DEF +GENERATE_BUGLIST +GENERATE_CHI +GENERATE_DEPRECATEDLIST +GENERATE_DOCBOOK +GENERATE_DOCSET +GENERATE_ECLIPSEHELP +GENERATE_HTML +GENERATE_HTMLHELP +GENERATE_LATEX +GENERATE_LEGEND +GENERATE_MAN +GENERATE_PERLMOD +GENERATE_QHP +GENERATE_RTF +GENERATE_TAGFILE +GENERATE_TESTLIST +GENERATE_TODOLIST +GENERATE_TREEVIEW +GENERATE_XML +GRAPHICAL_HIERARCHY +GROUP_GRAPHS +GROUP_NESTED_COMPOUNDS +# HAVE_DOT +HHC_LOCATION +HIDE_COMPOUND_REFERENCE +HIDE_FRIEND_COMPOUNDS +HIDE_IN_BODY_DOCS +HIDE_SCOPE_NAMES +HIDE_UNDOC_CLASSES +HIDE_UNDOC_MEMBERS +HIDE_UNDOC_RELATIONS +HTML_COLORSTYLE_GAMMA +HTML_COLORSTYLE_HUE +HTML_COLORSTYLE_SAT +HTML_DYNAMIC_SECTIONS +HTML_EXTRA_FILES +HTML_EXTRA_STYLESHEET +HTML_FILE_EXTENSION +HTML_FOOTER +HTML_HEADER +HTML_INDEX_NUM_ENTRIES +HTML_OUTPUT +HTML_STYLESHEET +HTML_TIMESTAMP +IDL_PROPERTY_SUPPORT +IGNORE_PREFIX +IMAGE_PATH +INCLUDED_BY_GRAPH +INCLUDE_FILE_PATTERNS +INCLUDE_GRAPH +INCLUDE_PATH +INHERIT_DOCS +INLINE_GROUPED_CLASSES +INLINE_INFO +INLINE_INHERITED_MEMB +INLINE_SIMPLE_STRUCTS +INLINE_SOURCES +INPUT +INPUT_ENCODING +INPUT_FILTER +INTERACTIVE_SVG +INTERNAL_DOCS +JAVADOC_AUTOBRIEF +LATEX_BATCHMODE +LATEX_BIB_STYLE +LATEX_CMD_NAME +LATEX_EXTRA_FILES +LATEX_EXTRA_STYLESHEET +LATEX_FOOTER +LATEX_HEADER +LATEX_HIDE_INDICES +LATEX_OUTPUT +LATEX_SOURCE_CODE +LATEX_TIMESTAMP +LAYOUT_FILE +LOOKUP_CACHE_SIZE +MACRO_EXPANSION +MAKEINDEX_CMD_NAME +MAN_EXTENSION +MAN_LINKS +MAN_OUTPUT +MAN_SUBDIR +MARKDOWN_SUPPORT +MATHJAX_CODEFILE +MATHJAX_EXTENSIONS +MATHJAX_FORMAT +MATHJAX_RELPATH +MAX_DOT_GRAPH_DEPTH +MAX_INITIALIZER_LINES +MSCFILE_DIRS +MSCGEN_PATH +MULTILINE_CPP_IS_BRIEF +OPTIMIZE_FOR_FORTRAN +OPTIMIZE_OUTPUT_FOR_C +OPTIMIZE_OUTPUT_JAVA +OPTIMIZE_OUTPUT_VHDL +OUTPUT_DIRECTORY +OUTPUT_LANGUAGE +PAPER_TYPE +PDF_HYPERLINKS +PERLMOD_LATEX +PERLMOD_MAKEVAR_PREFIX +PERLMOD_PRETTY +PERL_PATH +PLANTUML_CFG_FILE +PLANTUML_INCLUDE_PATH +PLANTUML_JAR_PATH +PREDEFINED +PROJECT_BRIEF +PROJECT_LOGO +PROJECT_NAME +PROJECT_NUMBER +QCH_FILE +QHG_LOCATION +QHP_CUST_FILTER_ATTRS +QHP_CUST_FILTER_NAME +QHP_NAMESPACE +QHP_SECT_FILTER_ATTRS +QHP_VIRTUAL_FOLDER +QT_AUTOBRIEF +QUIET +RECURSIVE +REFERENCED_BY_RELATION +REFERENCES_LINK_SOURCE +REFERENCES_RELATION +REPEAT_BRIEF +RTF_EXTENSIONS_FILE +RTF_HYPERLINKS +RTF_OUTPUT +RTF_SOURCE_CODE +RTF_STYLESHEET_FILE +SEARCHDATA_FILE +SEARCHENGINE +SEARCHENGINE_URL +SEARCH_INCLUDES +SEPARATE_MEMBER_PAGES +SERVER_BASED_SEARCH +SHORT_NAMES +SHOW_FILES +SHOW_GROUPED_MEMB_INC +SHOW_INCLUDE_FILES +SHOW_NAMESPACES +SHOW_USED_FILES +SIP_SUPPORT +SKIP_FUNCTION_MACROS +SORT_BRIEF_DOCS +SORT_BY_SCOPE_NAME +SORT_GROUP_NAMES +SORT_MEMBERS_CTORS_1ST +SORT_MEMBER_DOCS +SOURCE_BROWSER +SOURCE_TOOLTIPS +STRICT_PROTO_MATCHING +STRIP_CODE_COMMENTS +STRIP_FROM_INC_PATH +STRIP_FROM_PATH +SUBGROUPING +TAB_SIZE +TAGFILES +TCL_SUBST +TEMPLATE_RELATIONS +TOC_EXPAND +TOC_INCLUDE_HEADINGS +TREEVIEW_WIDTH +TYPEDEF_HIDES_STRUCT +UML_LIMIT_NUM_FIELDS +UML_LOOK +USE_HTAGS +USE_MATHJAX +USE_MDFILE_AS_MAINPAGE +USE_PDFLATEX +VERBATIM_HEADERS +WARNINGS +WARN_AS_ERROR +WARN_FORMAT +WARN_IF_DOC_ERROR +WARN_IF_UNDOCUMENTED +WARN_LOGFILE +WARN_NO_PARAMDOC +XML_OUTPUT +XML_PROGRAMLISTING +) + +set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") + +function(add_doxygen_doc) + set(options) + set(oneValueArgs) + set(multiValueArgs DEPENDS ${DOXYGEN_ARGS}) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n") + + foreach(ARG ${DOXYGEN_ARGS}) + if(PARSE_${ARG}) + string(REPLACE ";" " " ARG_VALUE ${PARSE_${ARG}}) + file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n") + endif() + endforeach() + + if(PARSE_OUTPUT_DIRECTORY) + if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY}) + file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY}) + endif() + endif() + + if(DOT_EXECUTABLE) + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n") + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n") + else() + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n") + endif() + + add_custom_target(doxygen + ${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Building documentation with doxygen" + ) + if(PARSE_OUTPUT_DIRECTORY) + clean_doc_output(${PARSE_OUTPUT_DIRECTORY}) + endif() + mark_as_doc(doxygen) + if(PARSE_DEPENDS) + add_dependencies(doxygen ${PARSE_DEPENDS}) + endif() +endfunction() diff --git a/3rdparty/composable_kernel/cmake/EnableCompilerWarnings.cmake b/3rdparty/composable_kernel/cmake/EnableCompilerWarnings.cmake new file mode 100644 index 0000000000000000000000000000000000000000..78133af0315dd952771ecb481525c14f8f4abff5 --- /dev/null +++ b/3rdparty/composable_kernel/cmake/EnableCompilerWarnings.cmake @@ -0,0 +1,110 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +# - Enable warning all for gcc/clang or use /W4 for visual studio + +## Strict warning level +if (MSVC) + # Use the highest warning level for visual studio. + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w") + # set(CMAKE_CXX_WARNING_LEVEL 4) + # if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]") + # string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + # else () + # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") + # endif () + + # set(CMAKE_C_WARNING_LEVEL 4) + # if (CMAKE_C_FLAGS MATCHES "/W[0-4]") + # string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") + # else () + # set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4") + # endif () + +else() + foreach(COMPILER C CXX) + set(CMAKE_COMPILER_WARNINGS) + # use -Wall for gcc and clang + list(APPEND CMAKE_COMPILER_WARNINGS + -Wall + -Wextra + -Wcomment + -Wendif-labels + -Wformat + -Winit-self + -Wreturn-type + -Wsequence-point + # Shadow is broken on gcc when using lambdas + # -Wshadow + -Wswitch + -Wtrigraphs + -Wundef + -Wuninitialized + -Wunreachable-code + -Wunused + + -Wsign-compare + -Wno-extra-semi-stmt + ) + if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") + list(APPEND CMAKE_COMPILER_WARNINGS + -Weverything + -Wno-c++98-compat + -Wno-c++98-compat-pedantic + -Wno-conversion + -Wno-double-promotion + -Wno-exit-time-destructors + -Wno-extra-semi + -Wno-float-conversion + -Wno-gnu-anonymous-struct + -Wno-gnu-zero-variadic-macro-arguments + -Wno-missing-prototypes + -Wno-nested-anon-types + -Wno-padded + -Wno-return-std-move-in-c++11 + -Wno-shorten-64-to-32 + -Wno-sign-conversion + -Wno-unknown-warning-option + -Wno-unused-command-line-argument + -Wno-weak-vtables + -Wno-covered-switch-default + ) + else() + if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") + # cmake 3.5.2 does not support >=. + if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1") + list(APPEND CMAKE_COMPILER_WARNINGS + -Wno-ignored-attributes) + endif() + endif() + list(APPEND CMAKE_COMPILER_WARNINGS + -Wno-missing-field-initializers + -Wno-deprecated-declarations + ) + endif() + add_definitions(${CMAKE_COMPILER_WARNINGS}) + endforeach() +endif () diff --git a/3rdparty/composable_kernel/cmake/TargetFlags.cmake b/3rdparty/composable_kernel/cmake/TargetFlags.cmake new file mode 100644 index 0000000000000000000000000000000000000000..4f83fb5d396881a84cd2aeac7f91716966a02ab6 --- /dev/null +++ b/3rdparty/composable_kernel/cmake/TargetFlags.cmake @@ -0,0 +1,50 @@ + +function(get_target_property2 VAR TARGET PROPERTY) + get_target_property(_pflags ${TARGET} ${PROPERTY}) + if(_pflags) + set(${VAR} ${_pflags} PARENT_SCOPE) + else() + set(${VAR} "" PARENT_SCOPE) + endif() +endfunction() + + +macro(append_flags FLAGS TARGET PROPERTY PREFIX) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + foreach(FLAG ${_pflags}) + if(TARGET ${FLAG}) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + else() + string(APPEND ${FLAGS} " ${PREFIX}${FLAG}") + endif() + endforeach() +endmacro() + +macro(append_link_flags FLAGS TARGET PROPERTY) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + foreach(FLAG ${_pflags}) + if(TARGET ${FLAG}) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + elseif(FLAG MATCHES "^-.*") + string(APPEND ${FLAGS} " ${FLAG}") + elseif(EXISTS ${FLAG}) + string(APPEND ${FLAGS} " ${FLAG}") + else() + string(APPEND ${FLAGS} " -l${FLAG}") + endif() + endforeach() +endmacro() + +function(target_flags FLAGS TARGET) + set(_flags) + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "") + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D") + append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "") + append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "") + # message("_flags: ${_flags}") + set(${FLAGS} ${_flags} PARENT_SCOPE) +endfunction() diff --git a/3rdparty/composable_kernel/cmake/googletest.cmake b/3rdparty/composable_kernel/cmake/googletest.cmake new file mode 100644 index 0000000000000000000000000000000000000000..3da4877a1fca9ef7e61afdcd0b96c92be425537e --- /dev/null +++ b/3rdparty/composable_kernel/cmake/googletest.cmake @@ -0,0 +1,50 @@ +include(FetchContent) + +set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") + +if(GOOGLETEST_DIR) + set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") +endif() + +message(STATUS "Fetching GoogleTest") + +list(APPEND GTEST_CMAKE_CXX_FLAGS + -Wno-undef + -Wno-reserved-identifier + -Wno-global-constructors + -Wno-missing-noreturn + -Wno-disabled-macro-expansion + -Wno-used-but-marked-unused + -Wno-switch-enum + -Wno-zero-as-null-pointer-constant + -Wno-unused-member-function + -Wno-comma + -Wno-old-style-cast + -Wno-deprecated +) +message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") + +FetchContent_Declare( + googletest + GIT_REPOSITORY http://10.0.50.24/Mirrors/googletest.git +# GIT_REPOSITORY /work/home/zhangshao/installer/googletest + GIT_TAG b85864c64758dec007208e56af933fc3f52044ee +) + +# Will be necessary for windows build +# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_GetProperties(googletest) +if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) +endif() + +target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) + +set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/3rdparty/composable_kernel/compile.sh b/3rdparty/composable_kernel/compile.sh new file mode 100644 index 0000000000000000000000000000000000000000..92521ea8ccffb8e0b2c8c043bdff39cd98c283fb --- /dev/null +++ b/3rdparty/composable_kernel/compile.sh @@ -0,0 +1,13 @@ +mkdir build -p +cd build + +cmake \ +-D CMAKE_PREFIX_PATH=${ROCM_PATH} \ +-D CMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc \ +-D CMAKE_CXX_FLAGS="-O3" \ +-D CMAKE_BUILD_TYPE=Release \ +-D GPU_TARGETS="gfx906;gfx926" \ +-D CMAKE_INSTALL_PREFIX=~/composable_kernel/install_ck \ +.. + +cd - diff --git a/3rdparty/composable_kernel/dev-requirements.txt b/3rdparty/composable_kernel/dev-requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9e7b9f01e1cae3901bd0f1661f39bc383300ebd6 --- /dev/null +++ b/3rdparty/composable_kernel/dev-requirements.txt @@ -0,0 +1,3 @@ +ROCmSoftwarePlatform/rocm-recipes +RadeonOpenCompute/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build +danmar/cppcheck@2.9 \ No newline at end of file diff --git a/3rdparty/composable_kernel/doc/image/ck_component.png b/3rdparty/composable_kernel/doc/image/ck_component.png new file mode 100644 index 0000000000000000000000000000000000000000..db892331d77273208f861eb68f1df46d0f2667f4 Binary files /dev/null and b/3rdparty/composable_kernel/doc/image/ck_component.png differ diff --git a/3rdparty/composable_kernel/doc/image/ck_layer.png b/3rdparty/composable_kernel/doc/image/ck_layer.png new file mode 100644 index 0000000000000000000000000000000000000000..117a1b3a0ef890a0a2db9bf23f14f03278d1d965 Binary files /dev/null and b/3rdparty/composable_kernel/doc/image/ck_layer.png differ diff --git a/3rdparty/composable_kernel/doc/markdown/dockerhub.md b/3rdparty/composable_kernel/doc/markdown/dockerhub.md new file mode 100644 index 0000000000000000000000000000000000000000..91b6cb2295cf4a8dd08ead2eb165119c74996084 --- /dev/null +++ b/3rdparty/composable_kernel/doc/markdown/dockerhub.md @@ -0,0 +1,93 @@ +## CK docker hub + +[Docker hub](https://hub.docker.com/r/rocm/composable_kernel) + +## Why do I need this? + +To make our lives easier and bring Composable Kernel dependencies together, we recommend using docker images. + +## So what is Composable Kernel? + +Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. + +To get the CK library + +``` +git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git +``` + +run a docker container + +``` +docker run \ +-it \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ +/bin/bash +``` + +and build the CK + +``` +mkdir build && cd build + +# Need to specify target ID, example below is for gfx908 and gfx90a +cmake \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_CXX_FLAGS="-O3" \ +-D CMAKE_BUILD_TYPE=Release \ +-D GPU_TARGETS="gfx908;gfx90a" \ +.. +``` + +and + +``` +make -j examples tests +``` + +To run all the test cases including tests and examples run + +``` +make test +``` + +We can also run specific examples or tests like + +``` +./bin/example_gemm_xdl_fp16 +./bin/test_gemm_fp16 +``` + +For more details visit [CK github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel), [CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/example), [even more CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/client_example). + +## And what is inside? + +The docker images have everything you need for running CK including: + +* [ROCm](https://www.amd.com/en/graphics/servers-solutions-rocm) +* [CMake](https://cmake.org/) +* [Compiler](https://github.com/RadeonOpenCompute/llvm-project) + +## Which image is right for me? + +Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are: + +* "ck" - made for running Composable Kernel +* "ub20.04" - based on Ubuntu 20.04 +* "rocm5.4" - ROCm platform version 5.4 +* "release" - compiler version is release + +So just pick the right image for your project dependencies and you're all set. + +## DIY starts here + +If you need to customize a docker image or just can't stop tinkering, feel free to adjust the [Dockerfile](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile) for your needs. + +## License + +CK is released under the MIT [license](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/LICENSE). diff --git a/3rdparty/composable_kernel/example/01_gemm/CMakeLists.txt b/3rdparty/composable_kernel/example/01_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c403e51ed99d10cef3ecf5bc631f3f20fa48a1ab --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/CMakeLists.txt @@ -0,0 +1,37 @@ +add_custom_target(example_gemm_dl) + +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) + +add_dependencies(example_gemm_dl example_gemm_dl_fp32) +add_dependencies(example_gemm_dl example_gemm_dl_fp16) +add_dependencies(example_gemm_dl example_gemm_dl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_int4) +endif(USE_BITINT_EXTENSION_INT4) + + +add_custom_target(example_gemm_xdl) + +add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) + +add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) +add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) +add_dependencies(example_gemm_xdl example_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) + +add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) + +add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) +add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) diff --git a/3rdparty/composable_kernel/example/01_gemm/README.md b/3rdparty/composable_kernel/example/01_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..226783b03b0b58cfe4de15a189f07a8324b3bc1d --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/README.md @@ -0,0 +1,23 @@ +# Instructions for ```example_gemm_xdl``` + +## Run ```example_gemm_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./bin/example_gemm_xdl 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s +``` diff --git a/3rdparty/composable_kernel/example/01_gemm/common.hpp b/3rdparty/composable_kernel/example/01_gemm/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..495a8159623beb22d6002bb2e1667ef459b74139 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/common.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_dl_fp16.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf585a8c51cbb9b5b0218228e4aa706189598804 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_fp16.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_dl_fp32.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93f085cdee53667a5b906cfed3b037d57d00bf5f --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_fp32.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" + +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_dl_int4.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e392c490f29a48da3a6424f46876e220306f2907 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_int4.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#error Should compile this file with ck::int4_t support +#endif + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using CDataType = ck::int4_t; +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelCDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < KernelADataType, KernelBDataType, KernelCDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#define BUILD_INT4_EXAMPLE +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_dl_int8.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be9e387718f120fb1ba708d374f0ff9a09fc806d --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_dl_int8.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_bf16.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9aaae6ade9564aa54bc4ea0c7f2c96aac89ff505 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_bf16.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..488babb7588df3f754e7ba1c56eb739a24ac73b5 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_fp16.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; +// clang-format on + +// clang-format off +using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using DeviceGemmInstance = DeviceGemmInstance0; + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_fp64.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_fp64.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99253b743d58707707c9765c18e5933cb09e4220 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_fp64.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" + +using ADataType = double; +using BDataType = double; +using CDataType = double; +using AccDataType = double; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if 0 + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>; +#else + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; +#endif + // clang-format on + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_int4.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f1283a47b36c94be9645abd9fbd63094293f46e --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_int4.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#error Should compile this file with ck::int4_t support +#endif + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using CDataType = ck::int4_t; +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelCDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, KernelADataType, KernelBDataType, KernelCDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#define BUILD_INT4_EXAMPLE +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_int8.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e67594c5bcbd601ae4747ac7720dab66b605ba2c --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_int8.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12a69925977b7eef2e5aa178016a7d433e9c0781 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp" + +#include "ck/library/utility/literals.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +#define USING_SKIP_LDS 1 + +// clang-format off +#if USING_SKIP_LDS +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipBLds + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| BBlock| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| buffer| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| size | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if 0 + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 8, 7, 1>; +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; +#else + < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; +#endif + +#else +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>; +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +#endif + // clang-format on + + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +template +std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) +{ + os << "[" << std::endl; + for(size_t x = 0; x < matrix.mDesc.GetLengths()[0]; x++) + { + os << "["; + for(size_t y = 0; y < matrix.mDesc.GetLengths()[1]; y++) + { + os << std::setw(5) << static_cast(matrix(x, y)); + } + os << "]" << std::endl; + } + os << "]"; + return os; +} +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + bool time_kernel = false; + + // GEMM shape +#if 1 + ck::index_t M = 16; + ck::index_t N = 64 * 120; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideC = N; +#else + ck::index_t M = 16; + ck::index_t N = 16; + ck::index_t K = 32; + + ck::index_t StrideA = 8; + ck::index_t StrideB = 8; + ck::index_t StrideC = 16; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + // a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + +#if 0 + { + show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl; + show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl; + show_2d_matrix(std::cout << "c_device: ", c_m_n_device_result) << std::endl; + show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl; + } +#endif + ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/01_gemm/run_gemm_example.inc b/3rdparty/composable_kernel/example/01_gemm/run_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..4e2cedb52adc8c7030167ec7f171d91f0796f234 --- /dev/null +++ b/3rdparty/composable_kernel/example/01_gemm/run_gemm_example.inc @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#endif + } + + return true; +} + +bool run_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/3rdparty/composable_kernel/example/02_gemm_bilinear/CMakeLists.txt b/3rdparty/composable_kernel/example/02_gemm_bilinear/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..10ec0f1a71151668e262efcdbaff7100d2d08dfa --- /dev/null +++ b/3rdparty/composable_kernel/example/02_gemm_bilinear/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/02_gemm_bilinear/README.md b/3rdparty/composable_kernel/example/02_gemm_bilinear/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9eb87e1e3479d72497ec72956b1de649b0ff735f --- /dev/null +++ b/3rdparty/composable_kernel/example/02_gemm_bilinear/README.md @@ -0,0 +1,28 @@ +# Instructions for ```example_gemm_bilinear_xdl_fp16``` + +## Run ```example_gemm_bilinear_xdl_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE +#arg11 to 12: alpha, beta +./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5 +``` +Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c0_grid_desc_m_n_{ 3840, 4096} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s +error: 0 +max_diff: 0, 558.5, 558.5 +``` diff --git a/3rdparty/composable_kernel/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/3rdparty/composable_kernel/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..917b6b1c3142ec8cf7d1c852ea59494fea1842a9 --- /dev/null +++ b/3rdparty/composable_kernel/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/03_gemm_bias_relu/CMakeLists.txt b/3rdparty/composable_kernel/example/03_gemm_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..35c54abac03094f24187df2503aa02b6812c20f3 --- /dev/null +++ b/3rdparty/composable_kernel/example/03_gemm_bias_relu/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/03_gemm_bias_relu/README.md b/3rdparty/composable_kernel/example/03_gemm_bias_relu/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f28a9a071c879e92be34f84054661647c31ebb35 --- /dev/null +++ b/3rdparty/composable_kernel/example/03_gemm_bias_relu/README.md @@ -0,0 +1,10 @@ +# Instructions for ```example_gemm_bias_relu_xdl_fp16``` + +## Run ```example_gemm_bias_relu_xdl_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE +./bin/example_gemm_bias_relu_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 +``` diff --git a/3rdparty/composable_kernel/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/3rdparty/composable_kernel/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aee51d05de58b6386a5dd267d45a7e5c12d98276 --- /dev/null +++ b/3rdparty/composable_kernel/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// C = A * B +// E = Relu(C + D); +struct AddRelu +{ + __host__ __device__ void + operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const + { + const ck::half_t x = c + d; + + e = x > 0 ? x : 0; + } +}; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(EDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c75c5ba51e899e4d0cba6564f52a2abb63854dd8 --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1,17 @@ +add_custom_target(example_gemm_add_add_fastgelu_xdl) + +add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) +add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + +add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) +add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) +add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) +if(USE_BITINT_EXTENSION_INT4) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) +add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/README.md b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/README.md new file mode 100644 index 0000000000000000000000000000000000000000..08a55fb9a37f9f7823ae8882d4d36e9bd96ee6fd --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/README.md @@ -0,0 +1,23 @@ +# Instructions for ```example_gemm_add_add_fastgelu_xdl_fp16``` + +## Run ```example_gemm_add_add_fastgelu_xdl_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" +./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} +d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8> +``` diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/common.hpp b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f9375e09266c437b6d5ceb9023a06ba0654aed7 --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/common.hpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using I4 = ck::int4_t; +#endif +using I8 = int8_t; +using I32 = int32_t; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideD0 = std::stoi(argv[9]); + problem_size.StrideD1 = std::stoi(argv[10]); + problem_size.StrideE = std::stoi(argv[11]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, " + "StrideE" + << std::endl; + return false; + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e50c14dc2b8a8473c3ee5a16a867eebbd147447 --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using D1DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c7ca41444846df8decdaabea32581f83ba7eff5 --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1ef266f23dfb7baebbc9fe7ec988ffcc70cb91b3 --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b5bc9879b26181031344cb75d4eab57fbae8a6e --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#error Should compile this file with ck::int4_t support +#endif + +#include "common.hpp" + +using ADataType = I4; +using BDataType = I4; +using AccDataType = I32; +using CShuffleDataType = I32; +using D0DataType = I4; +using D1DataType = I4; +using DsDataType = ck::Tuple; +using EDataType = I4; + +using KernelADataType = I8; +using KernelBDataType = I8; +using KernelD0DataType = I8; +using KernelD1DataType = I8; +using KernelDsDataType = ck::Tuple; +using KernelEDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, KernelDsDataType, KernelEDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#define BUILD_INT4_EXAMPLE +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b236f5e9987e980616951b0b2fd0b59c78ace13d --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using D0DataType = I8; +using D1DataType = I8; +using DsDataType = ck::Tuple; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..f3def33b5670cd7269ad0a521695943674c9ff19 --- /dev/null +++ b/3rdparty/composable_kernel/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc @@ -0,0 +1,166 @@ +#pragma once + +bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor< +#ifdef BUILD_INT4_EXAMPLE + KernelEDataType +#else + EDataType +#endif + > + e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + const Tensor d0_m_n_converted(d0_m_n); + const Tensor d1_m_n_converted(d1_m_n); + + a_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_device_buf.ToDevice(b_k_n_converted.mData.data()); + d0_device_buf.ToDevice(d0_m_n_converted.mData.data()); + d1_device_buf.ToDevice(d1_m_n_converted.mData.data()); +#else + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); +#endif + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(D0DataType) * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; + + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor e_m_n_device_result_converted(e_m_n_device_result); + + return ck::utils::check_err(e_m_n_device_result_converted, e_m_n_host_result); +#else + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); +#endif + } + + return true; +} + +bool run_gemm_add_add_fastgelu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || + run_gemm_add_add_fastgelu(problem_size, config); +} diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/CMakeLists.txt b/3rdparty/composable_kernel/example/09_convnd_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e0a53005b8c2846b6c3baaee40272c9ad0858188 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/CMakeLists.txt @@ -0,0 +1,11 @@ +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) + +add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) +add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) +add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) + diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/README.md b/3rdparty/composable_kernel/example/09_convnd_fwd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ab5fee549d44aefd22590d66681127a2816d9f9 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/README.md @@ -0,0 +1,32 @@ +# Instructions for ```example_convnd_fwd_xdl``` + +## Run ```example_convnd_fwd_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: N spatial dimensions (default 2) +#Following arguments (depending on number of spatial dims): +# N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) +./bin/example_convnd_fwd_xdl 0 1 100 +``` + +Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) +``` +input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{432, 165888, 4} +arg.b_grid_desc_k0_n_k1_{432, 256, 4} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 100 times... +Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s +``` diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/compile.sh b/3rdparty/composable_kernel/example/09_convnd_fwd/compile.sh new file mode 100644 index 0000000000000000000000000000000000000000..340688647117257d4c65ea55aa178e4962c9cc41 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/compile.sh @@ -0,0 +1,10 @@ +mkdir build -p +cd build + +cmake \ +-D CMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc \ +-D CMAKE_CXX_FLAGS="-O3" \ +-D CMAKE_BUILD_TYPE=Release \ +.. + +cd - diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_common.hpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4c594ccdf817a069918b348d7b23c6cc533bde28 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err( + out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_common.hpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..855710b9d9a30ec47c74a6f31d1b03ade688e28e --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_common.hpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +bool run_grouped_conv_fwd_dl(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + using DDataType = ck::remove_cvref_t>; + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + bias.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_1{1}); + wei.GenerateTensorValue(GeneratorTensor_1{-1}); + bias.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(DDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d_g_n_k_wos_lengths{}; + std::array d_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 1>{{d_g_n_k_wos_lengths}}, + std::array, 1>{{d_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_conv with the specified compilation parameters does not " + "support this Conv problem" + << std::endl; + return true; + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + NDimSpatial, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + ck::tensor_operation::element_wise::PassThrough>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = + ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + ck::tensor_operation::element_wise::PassThrough{}); + + ref_invoker.Run(ref_argument); + + // cde_elementwise + out_host.ForEach( + [&](auto&, auto idx) { out_element_op(out_host(idx), out_host(idx), bias(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err( + out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db5a7f0bc3351c9f1670993942a54668c191e56f --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_dl_common.hpp" + +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using DsDataType = ck::Tuple; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +// clang-format off +using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +// ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, InDataType, WeiDataType, DsDataType, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +#include "run_convnd_fwd_dl_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_dl_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..964d784c8592b2e94bc38cc8eab6b7a645d88fc1 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_dl_common.hpp" + +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = float; +using WeiDataType = float; +using AccDataType = float; +using DsDataType = ck::Tuple; +using OutDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +// clang-format off +using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +// ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, InDataType, WeiDataType, DsDataType, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +#include "run_convnd_fwd_dl_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_dl_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0cd88f214c8550c018ad9f4d268370dbd816c7f --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_dl_common.hpp" + +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using AccDataType = int32_t; +using DsDataType = ck::Tuple; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +// clang-format off +using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +// ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, InDataType, WeiDataType, DsDataType, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +#include "run_convnd_fwd_dl_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_dl_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d55d3154916b0f4ba2fee195b7957904ae3d3139 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; +using OutDataType = ck::bhalf_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d84afba6426b1abe4b10a3b4da1fefe7c2e9272a --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5acc540cf98b8876caa0d5c0812436dd200deb8 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = float; +using WeiDataType = float; +using AccDataType = float; +using CShuffleDataType = float; +using OutDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d697976abd43e5d07c49ec4e0819cbdf15ac02b --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = double; +using WeiDataType = double; +using AccDataType = double; +using CShuffleDataType = double; +using OutDataType = double; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 8, // KPerBlock + 2, // AK1 + 2, // BK1 + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 2, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 1>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99f7f2565c748a950841dffa12a1d343df0051ab --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 16>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc b/3rdparty/composable_kernel/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..697ada14ba960f7231db3aa0d5b2482ca38578b8 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_convnd_fwd_dl_example(int argc, char* argv[]) +{ + print_helper_msg(); + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + std::cout << "ndim_spatial_value: " << ndim_spatial_value << std::endl; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd_dl< + ndim_spatial_value, + InDataType, + WeiDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceGroupedConvNDFwdInstance>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/09_convnd_fwd/run_convnd_fwd_example.inc b/3rdparty/composable_kernel/example/09_convnd_fwd/run_convnd_fwd_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..36a68056f1d7a34b4309a06941f0ff1f477eaba8 --- /dev/null +++ b/3rdparty/composable_kernel/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd< + ndim_spatial_value, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceGroupedConvNDFwdInstance>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..98941b4db53b022388023e5a0eba10c99b5881be --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -0,0 +1,16 @@ +add_custom_target(example_convnd_fwd_reduce_xdl) + +add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + +add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) +add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) +add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) +add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00e370f2968df0409155d5407066e863574613ce --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using BF16 = ck::bhalf_t; +using FP16 = ck::half_t; +using FP32 = float; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using I4 = ck::int4_t; +#endif +using I8 = std::int8_t; +using I32 = std::int32_t; + +template +struct LayoutSetting +{ + using ALayout = ALay; + using BLayout = BLay; + using DELayout = DELay; + using RLayout = RLay; +}; + +template +struct LayoutSettingSelector; + +namespace ctl = ck::tensor_layout::convolution; + +template <> +struct LayoutSettingSelector<1> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<2> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<3> final + : LayoutSetting +{ +}; + +template +using ALayout = typename LayoutSettingSelector::ALayout; + +template +using BLayout = typename LayoutSettingSelector::BLayout; + +template +using DELayout = typename LayoutSettingSelector::DELayout; + +template +using RLayout = typename LayoutSettingSelector::RLayout; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ck::utils::conv::ConvParam& problem_size, + ExecutionConfig& config) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + problem_size = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} + +inline HostTensorDescriptor +make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size) +{ + std::vector dimensions{problem_size.G_, problem_size.N_}; + + ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions)); + + return HostTensorDescriptor(dimensions); +} + +template +void unpack_host_tensor_descriptor(const HostTensorDescriptor& descriptor, + Lengths& lengths, + Strides& strides) +{ + assert(size(descriptor.GetLengths()) == size(lengths)); + std::copy_n(begin(descriptor.GetLengths()), size(descriptor.GetLengths()), begin(lengths)); + + assert(size(descriptor.GetStrides()) == size(strides)); + std::copy_n(begin(descriptor.GetStrides()), size(descriptor.GetStrides()), begin(strides)); +} diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ff29b4b0ff0f11d6c9d3becbd3aaafdccf3020e --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = FP32; +using CShuffleDataType = FP32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; +using ReduceAccDataType = FP32; +using R0DataType = FP32; +using RsDataType = ck::Tuple; + +#include "run_convnd_fwd_max_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02c19c2b63bf4f9ee663140b414c89f6221cff51 --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = FP16; +using BDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP32; +using DsDataType = ck::Tuple<>; +using EDataType = FP16; +using ReduceAccDataType = FP32; +using R0DataType = FP32; +using RsDataType = ck::Tuple; + +#include "run_convnd_fwd_max_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..679bb5c0c45a1e2e74fc56d21ee6f7f73cae76ad --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = FP32; +using BDataType = FP32; +using AccDataType = FP32; +using CShuffleDataType = FP32; +using DsDataType = ck::Tuple<>; +using EDataType = FP32; +using ReduceAccDataType = FP32; +using R0DataType = FP32; +using RsDataType = ck::Tuple; + +#include "run_convnd_fwd_max_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..abdbdaf74d5c0bd2b87da17d315e8712cd896d52 --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#error Should compile this file with ck::int4_t support +#endif + +#define BUILD_INT4_EXAMPLE + +#include "common.hpp" + +using ADataType = I4; +using BDataType = I4; +using KernelADataType = I8; +using KernelBDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I32; +using ReduceAccDataType = I32; +using R0DataType = I32; +using RsDataType = ck::Tuple; + +#include "run_convnd_fwd_max_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf86afa8e94957c01fc92acb5ed2286fcb52466c --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I32; +using ReduceAccDataType = I32; +using R0DataType = I32; +using RsDataType = ck::Tuple; + +#include "run_convnd_fwd_max_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..b3a3891781766acd824bc1974b88f57fdd85b711 --- /dev/null +++ b/3rdparty/composable_kernel/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using RsThreadReduceOp = ck::Tuple; + +using RsGlobalReduceOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +template +using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle +//######| NDimSpatial| ALayout| BLayout| DELayout| RLayout| AData| BData| AccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| Conv| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer| +//######| | | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Fwd|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| +//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| Specialization| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | +#ifdef BUILD_INT4_EXAMPLE + < NDimSpatial, ALayout, BLayout, DELayout, RLayout, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; +#else + < NDimSpatial, ALayout, BLayout, DELayout, RLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; +#endif + +template +using HostInstance = ck::tensor_operation::host::ReferenceConvFwd + ; +// clang-format on + +template +bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, + const ExecutionConfig& config) +{ + static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial"); + +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + const auto conv_input_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed>( + problem_size); + + const auto conv_weight_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed>( + problem_size); + + const auto conv_output_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed>( + problem_size); + + const auto r0_desc = make_r0_host_tensor_descriptor(problem_size); + + Tensor conv_input(conv_input_g_n_c_wis_desc); + Tensor conv_weight(conv_weight_g_k_c_xs_desc); + Tensor conv_output_device(conv_output_g_n_k_wos_desc); + Tensor r0_device(r0_desc); + + switch(config.init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); + ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_weight); + break; + default: + ck::utils::FillUniformDistribution{-5, 5}(conv_input); + ck::utils::FillUniformDistribution{-5, 5}(conv_weight); + } + + DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize()); + DeviceMem conv_weight_device_buf(sizeof(BDataType) * conv_weight.mDesc.GetElementSpaceSize()); + DeviceMem conv_output_device_buf(sizeof(EDataType) * + conv_output_device.mDesc.GetElementSpaceSize()); + DeviceMem r0_device_buf(sizeof(R0DataType) * r0_device.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor conv_input_converted(conv_input); + const Tensor conv_weight_converted(conv_weight); + + conv_input_device_buf.ToDevice(conv_input_converted.mData.data()); + conv_weight_device_buf.ToDevice(conv_weight_converted.mData.data()); +#else + conv_input_device_buf.ToDevice(conv_input.mData.data()); + conv_weight_device_buf.ToDevice(conv_weight.mData.data()); +#endif + + std::array conv_input_g_n_c_wis_lengths{}, + conv_input_g_n_c_wis_strides{}; + std::array conv_weight_g_k_c_xs_lengths{}, + conv_weight_g_k_c_xs_strides{}; + std::array conv_output_g_n_k_wos_lengths{}, + conv_output_g_n_k_wos_strides{}; + std::array r0_lengths{}, r0_strides{}; + std::array conv_filter_strides{}, conv_filter_dilations{}; + std::array input_left_pads{}, input_right_pads{}; + + unpack_host_tensor_descriptor( + conv_input_g_n_c_wis_desc, conv_input_g_n_c_wis_lengths, conv_input_g_n_c_wis_strides); + unpack_host_tensor_descriptor( + conv_weight_g_k_c_xs_desc, conv_weight_g_k_c_xs_lengths, conv_weight_g_k_c_xs_strides); + unpack_host_tensor_descriptor( + conv_output_g_n_k_wos_desc, conv_output_g_n_k_wos_lengths, conv_output_g_n_k_wos_strides); + unpack_host_tensor_descriptor(r0_desc, r0_lengths, r0_strides); + + ck::ranges::copy(problem_size.conv_filter_strides_, begin(conv_filter_strides)); + ck::ranges::copy(problem_size.conv_filter_dilations_, begin(conv_filter_dilations)); + ck::ranges::copy(problem_size.input_left_pads_, begin(input_left_pads)); + ck::ranges::copy(problem_size.input_right_pads_, begin(input_right_pads)); + + // run Conv + Reduction on device + auto conv = DeviceInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(conv_input_device_buf.GetDeviceBuffer(), + conv_weight_device_buf.GetDeviceBuffer(), + std::array{}, + conv_output_device_buf.GetDeviceBuffer(), + {r0_device_buf.GetDeviceBuffer()}, + conv_input_g_n_c_wis_lengths, + conv_input_g_n_c_wis_strides, + conv_weight_g_k_c_xs_lengths, + conv_weight_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + conv_output_g_n_k_wos_lengths, + conv_output_g_n_k_wos_strides, + r0_lengths, + r0_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + AElementOp{}, + BElementOp{}, + CDEElementOp{}, + QsElementOp{}, + RsElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return false; + } + + const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + const std::size_t flop = problem_size.GetFlops(); + const std::size_t num_btype = problem_size.GetByte(); + + const float tflops = static_cast(flop) / 1.E9 / avg_time; + const float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(config.do_verification) + { + Tensor conv_output_host(conv_output_g_n_k_wos_desc); + + // run Conv + Reduction on host + auto ref_conv = HostInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(conv_input, + conv_weight, + conv_output_host, + problem_size.conv_filter_strides_, + problem_size.conv_filter_dilations_, + problem_size.input_left_pads_, + problem_size.input_right_pads_, + AElementOp{}, + BElementOp{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + Tensor r0_host(r0_device.mDesc); + + auto reduce0_op = RsThreadReduceOp{}[ck::Number<0>{}]; + + auto& output_dims = conv_output_g_n_k_wos_desc.GetLengths(); + + if constexpr(NDimSpatial == 1) + { + for(std::size_t g = 0; g < output_dims[0]; ++g) + { + for(std::size_t n = 0; n < output_dims[1]; ++n) + { + for(std::size_t w = 0; w < output_dims[3]; ++w) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + for(std::size_t k = 0; k < output_dims[2]; ++k) + { + + auto e_val = + ck::type_convert(conv_output_host(g, n, k, w)); + reduce0_op(reduce0_acc, e_val); + } + r0_host(g, n, w) = ck::type_convert(reduce0_acc); + } + } + } + } + else if constexpr(NDimSpatial == 2) + { + for(std::size_t g = 0; g < output_dims[0]; ++g) + { + for(std::size_t n = 0; n < output_dims[1]; ++n) + { + for(std::size_t h = 0; h < output_dims[3]; ++h) + { + for(std::size_t w = 0; w < output_dims[4]; ++w) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + for(std::size_t k = 0; k < output_dims[2]; ++k) + { + + auto e_val = ck::type_convert( + conv_output_host(g, n, k, h, w)); + reduce0_op(reduce0_acc, e_val); + } + r0_host(g, n, h, w) = ck::type_convert(reduce0_acc); + } + } + } + } + } + else if constexpr(NDimSpatial == 3) + { + for(std::size_t g = 0; g < output_dims[0]; ++g) + { + for(std::size_t n = 0; n < output_dims[1]; ++n) + { + for(std::size_t d = 0; d < output_dims[3]; ++d) + { + for(std::size_t h = 0; h < output_dims[4]; ++h) + { + for(std::size_t w = 0; w < output_dims[5]; ++w) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + for(std::size_t k = 0; k < output_dims[2]; ++k) + { + + auto e_val = ck::type_convert( + conv_output_host(g, n, k, d, h, w)); + reduce0_op(reduce0_acc, e_val); + } + r0_host(g, n, d, h, w) = ck::type_convert(reduce0_acc); + } + } + } + } + } + } + + conv_output_device_buf.FromDevice(conv_output_device.mData.data()); + r0_device_buf.FromDevice(r0_device.mData.data()); + + return ck::utils::check_err(conv_output_device, + conv_output_host, + "Error: incorrect results! (Matrix E)", + 1e-5f, + 1e-4f) && + ck::utils::check_err( + r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f); + } + + return true; +} + +bool run_convnd_fwd_max_example(int argc, char* argv[]) +{ + ck::utils::conv::ConvParam problem_size{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + ExecutionConfig config; + + if(!parse_cmd_args(argc, argv, problem_size, config)) + { + return false; + } + + switch(problem_size.num_dim_spatial_) + { + case 1: return run_convnd_fwd_max<1>(problem_size, config); + case 2: return run_convnd_fwd_max<2>(problem_size, config); + case 3: return run_convnd_fwd_max<3>(problem_size, config); + } + + return false; +} diff --git a/3rdparty/composable_kernel/example/12_reduce/CMakeLists.txt b/3rdparty/composable_kernel/example/12_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e58ed933807d66a89967b6e278e084ccbe77d4d --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_reduce_blockwise reduce_blockwise.cpp) +add_example_executable(example_reduce_multiblock_atomic_add reduce_multiblock_atomic_add.cpp) +add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp) diff --git a/3rdparty/composable_kernel/example/12_reduce/README.md b/3rdparty/composable_kernel/example/12_reduce/README.md new file mode 100644 index 0000000000000000000000000000000000000000..76d28527bb8aa521611e54d172b427aabf9fb998 --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/README.md @@ -0,0 +1,62 @@ +# Instructions for ```example_reduce_blockwise``` + +## Run ```example_reduce_blockwise``` +```bash +# -D : input 3d/4d/5d tensor lengths +# -R : reduce dimension ids +# -v : verification (0=no, 1=yes) +#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 0 2 1 +``` + +Result +``` +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 0 2 1 +launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1> +``` + +## Run ```example_reduce_multiblock_atomic_add``` +```bash +# -D : input 3d/4d/5d tensor lengths +# -R : reduce dimension ids +# -v : verification (0=no, 1=yes) +#arg1: data type (0: fp32, 1: fp64) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_reduce_multiblock_atomic_add -D 16,64,32,960 -v 1 0 2 0 +``` + +Result +``` +./bin/example_reduce_multiblock_atomic_add -D 16,64,32,960 -v 1 0 2 0 +Perf: 0 ms, inf GB/s, DeviceReduceMultiBlock<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1> +echo $? +0 +``` + +# Instructions for ```example_reduce_blockwise_two_call``` + +## Run ```example_reduce_blockwise_two_call``` +```bash +#arg1: verification (0=no, 1=yes( +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_reduce_blockwise_two_call 1 2 1 +``` + +Result +``` +./bin/example_reduce_blockwise_two_call 1 2 1 +launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> +``` diff --git a/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise.cpp b/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7ee9990c1941a73b632f3cc1d32b14a00897cc2 --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise.cpp @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/utility/reduction_enums.hpp" +#include "reduce_blockwise_impl.hpp" +#include "reduce_example_common.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {16, 64, 32, 960}; + std::vector reduceDims = {0, 1, 2}; + std::vector scales = {1.0f, 0.0f}; + + bool do_verification = true; + int data_type = 1; + int init_method = 2; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)" + << std::endl; + std::cout << "Arg2 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg3 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:R:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 3 > argc) + { + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + }; + + data_type = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +template +bool reduce_blockwise_test(bool do_verification, + int init_method, + bool time_kernel, + const std::vector& inLengths, + const std::vector& reduceDims, + float alpha, + float beta) +{ + bool matched = false; + int result = 0; + + const auto tuple_object = reduce_shape_instances{}; + + static_for<0, std::tuple_size::value, 1>{}([&](auto i) { + if(matched) + return; + + using ShapeType = remove_cvref_t(tuple_object))>; + + if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size()) + return; + + std::array arrReduceDims; + + ck::ranges::copy(reduceDims, arrReduceDims.begin()); + + result = reduce_blockwise_impl( + do_verification, init_method, time_kernel, inLengths, arrReduceDims, alpha, beta); + + matched = true; + }); + + return (result == 0) ? true : false; +}; + +constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AVG; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + SimpleAppArgs arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + pass = reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + else if(arg.data_type == 1) + { + pass = reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + else if(arg.data_type == 3) + { + pass = reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + else if(arg.data_type == 5) + { + pass = reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + else if(arg.data_type == 6) + { + pass = reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + else if(arg.data_type == 7) + { + pass = reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + + pass = pass && reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } +#endif + } + else + { + // for testing half_t + pass = + pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing float + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing double + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing bhalf_t + pass = pass && + reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing int8_t + pass = + pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + // for testing int4_t using AVG operation + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing int4_t using MAX operation + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); +#endif + // for testing 3D input + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 960}, {0, 1}, 1.0f, 0.0f); + + // for testing 5D input + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 2, 960}, {0, 1, 2, 3}, 1.0f, 0.0f); + }; + + return (pass ? 0 : 1); +}; diff --git a/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise_impl.hpp b/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7bafd2d2bbfbc2a086fe4e09d005bde5cd5db5cf --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise_impl.hpp @@ -0,0 +1,338 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_reduction.hpp" + +#include "reduce_example_common.hpp" + +template +int reduce_blockwise_impl(bool do_verification, + int init_method, + bool time_kernel, + const std::vector& inLengths, + const std::array& reduceDims, + float alpha, + float beta) + +{ + using namespace ck; + using namespace ck::tensor_operation::device; + + constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim; + + constexpr bool op_support_indices = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + + constexpr bool invalid_reduce_1 = OutputIndex && !op_support_indices; + + // 1) If InOutDataType is half_t, must use half_t as AccDataType for indexable reduction + // operations 2) If InOutDataType is half_t, must use float as AccDataType for non-indexable + // reduction operations + constexpr bool invalid_reduce_2 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // 1) If InOutDataType is float, must use float as AccDataType for indexable reduction + // operations + constexpr bool invalid_reduce_3 = + std::is_same::value && + (op_support_indices && !std::is_same::value); + + // 1) If InOutDataType is int8_t or int4_t, must use int8_t as AccDataType for indexable + // reduction operations 2) If InOutDataType is int8_t or int4_t, must use int32_t as AccDataType + // for non-indexable reduction operations + constexpr bool invalid_reduce_4 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + constexpr bool invalid_reduce_4_2 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); +#endif + + // 1) If InOutDataType is int8_t or int4_t, the supported operation must be either indexable + // operations or ADD/AVG + constexpr bool invalid_reduce_5 = std::is_same::value && + (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD && + ReduceOpId != ReduceTensorOp::AVG); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + constexpr bool invalid_reduce_5_2 = std::is_same::value && + (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD && + ReduceOpId != ReduceTensorOp::AVG); +#endif + + // 1) If InOutDataType is bhalf_t, must use float as AccDataType for all reduction operations + constexpr bool invalid_reduce_6 = + std::is_same::value && !std::is_same::value; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + constexpr bool invalid_reduce = + (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || invalid_reduce_4 || + invalid_reduce_5 || invalid_reduce_6 || invalid_reduce_4_2 || invalid_reduce_5_2); +#else + constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || + invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6); +#endif + + if constexpr(invalid_reduce) + { + std::cerr << "The reduction setting is invalid, exiting!" << std::endl; + return (-1); + }; + + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + using InOutDataTypeInDevice = typename std:: + conditional::value, int8_t, InOutDataType>::type; +#else + using InOutDataTypeInDevice = InOutDataType; +#endif + + using DeviceReduceInstance = + ck::tensor_operation::device::DeviceReduceMultiBlock; // OutDstVectorSize + + Tensor in(inLengths); + + std::vector outLengths; + + auto invariantDims = get_invariant_dims(reduceDims); + + if(invariantDims.empty()) + outLengths.push_back(1); + else + for(auto dim : invariantDims) + outLengths.push_back(inLengths[dim]); + + Tensor out_ref(outLengths); + Tensor out(outLengths); + Tensor out_indices_ref(outLengths); + Tensor out_indices(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verification) + { + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InOutDataTypeInDevice) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_dev(sizeof(InOutDataTypeInDevice) * out.mDesc.GetElementSpaceSize()); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if(std::is_same::value) + { + std::vector tmp_buf(in.mData.size()); + + std::copy_n(in.mData.data(), in.mData.size(), tmp_buf.data()); + in_dev.ToDevice(tmp_buf.data()); + } + else +#endif + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + { +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if(std::is_same::value) + { + std::vector tmp_buf(in.mData.size()); + + std::copy_n(out.mData.data(), out.mData.size(), tmp_buf.data()); + out_dev.ToDevice(tmp_buf.data()); + } + else +#endif + out_dev.ToDevice(out.mData.data()); + }; + + size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0; + + DeviceMem out_index_dev(indicesSizeInBytes); + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); + + if(do_verification) + { + ReductionHost + hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); + }; + + std::array arrInLengths; + std::array arrInStrides; + std::array arrOutLengths; + std::array arrOutStrides; + + ck::ranges::copy(inLengths, arrInLengths.begin()); + ck::ranges::copy(inStrides, arrInStrides.begin()); + ck::ranges::copy(outLengths, arrOutLengths.begin()); + ck::ranges::copy(outStrides, arrOutStrides.begin()); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths, + arrInStrides, + arrOutLengths, + arrOutStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_index_dev.GetDeviceBuffer(), + in_elementwise_op, + acc_elementwise_op); + + if(!reduce.IsSupportedArgument(argument_ptr.get())) + { + std::cerr + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + + return (-2); + }; + + std::string reduce_name = reduce.GetTypeString(); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + + invariant_total_length * sizeof(InOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name + << std::endl; + + bool pass = true; + + if(do_verification) + { +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if(std::is_same::value) + { + std::vector tmp_buf(out.mData.size()); + + out_dev.FromDevice(tmp_buf.data()); + + std::copy_n(tmp_buf.data(), out.mData.size(), out.mData.data()); + } + else +#endif + out_dev.FromDevice(out.mData.data()); + + pass = pass && ck::utils::check_err(out, out_ref); + + if(OutputIndex) + { + out_index_dev.FromDevice(out_indices.mData.data()); + pass = pass && ck::utils::check_err(out_indices, out_indices_ref); + }; + }; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise_two_call.cpp b/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise_two_call.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39821f240ad5f83c49073a5f188d151fdb2d6473 --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/reduce_blockwise_two_call.cpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_reduction.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InOutDataType = ck::half_t; +using InOutDataType = ck::half_t; +using AccDataType = float; + +constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +using ReduceOperation = typename reduce_binary_operator::opType; +using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; +using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + +using PassThroughOp = tensor_operation::element_wise::PassThrough; + +using DeviceReduceInstance_1 = DeviceReduceMultiBlock; + +using DeviceReduceInstance_2 = DeviceReduceMultiBlock; + +static bool do_verify; +static int init_method; +static float alpha; +static float beta; +static bool time_kernel; + +int main(int argc, char* argv[]) +{ + // used by the device reduction + const std::array reduceDims_1 = {4}; + // const std::array invariantDims_1 = {0, 1, 2, 3}; + + const std::array reduceDims_2 = {3}; + // const std::array invariantDims_2 = {0, 1, 2}; + + // used by the host reduction + const std::array reduceDims = {3, 4}; + const std::array invariantDims = {0, 1, 2}; + + const std::vector inLengths_1 = {64, 320, 80, 4, 128}; + + // input lengths of the second reduction, which is also the output lengths of the first + // reduction + const std::vector inLengths_2 = {64, 320, 80, 4}; + + const std::vector outLengths = {64, 320, 80}; + + if(argc == 1) + { + do_verify = true; + init_method = 2; + time_kernel = true; + } + else if(argc == 4) + { + do_verify = static_cast(argv[1]); + init_method = atoi(argv[2]); + time_kernel = static_cast(atoi(argv[3])); + } + else + { + std::ostringstream ostr; + + ostr << "Wrong parameter! " << std::endl + << "Usage: " << argv[0] << "[verify 0/1] init_method time_kernel" << std::endl; + + throw std::runtime_error(ostr.str()); + }; + + alpha = 1.0f; + beta = 0.0f; + + Tensor in_1(inLengths_1); + + Tensor out_ref(outLengths); + Tensor in_2(inLengths_2); // also the output tensor of the first reduction + Tensor out(outLengths); + + auto inStrides_1 = in_1.mDesc.GetStrides(); + auto inStrides_2 = in_2.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in_1.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verify) + { + switch(init_method) + { + case 0: break; + case 1: + in_1.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in_1.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in_1.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpaceSize()); + DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpaceSize()); + DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpaceSize()); + + in_1_dev.ToDevice(in_1.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); + + if(do_verify) + { + ReductionHost + hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in_1.mData.data(), + beta, + out_ref.mData.data(), + nullptr, + in_elementwise_op, + acc_elementwise_op); + }; + + std::array arrInLengths_1; + std::array arrInStrides_1; + std::array arrInLengths_2; + std::array arrInStrides_2; + std::array arrOutLengths; + std::array arrOutStrides; + + ck::ranges::copy(inLengths_1, arrInLengths_1.begin()); + ck::ranges::copy(inStrides_1, arrInStrides_1.begin()); + ck::ranges::copy(inLengths_2, arrInLengths_2.begin()); + ck::ranges::copy(inStrides_2, arrInStrides_2.begin()); + ck::ranges::copy(outLengths, arrOutLengths.begin()); + ck::ranges::copy(outStrides, arrOutStrides.begin()); + + auto reduce_1 = DeviceReduceInstance_1{}; + + auto argument_ptr_1 = reduce_1.MakeArgumentPointer(arrInLengths_1, + arrInStrides_1, + arrInLengths_2, + arrInStrides_2, + reduceDims_1, + 1.0f, + 0.0f, + in_1_dev.GetDeviceBuffer(), + nullptr, + in_2_dev.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + PassThroughOp{}); + + if(!reduce_1.IsSupportedArgument(argument_ptr_1.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + auto invoker_ptr_1 = reduce_1.MakeInvokerPointer(); + + auto reduce_2 = DeviceReduceInstance_2{}; + + auto argument_ptr_2 = reduce_2.MakeArgumentPointer(arrInLengths_2, + arrInStrides_2, + arrOutLengths, + arrOutStrides, + reduceDims_2, + alpha, + beta, + in_2_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + nullptr, + PassThroughOp{}, + acc_elementwise_op); + + if(!reduce_2.IsSupportedArgument(argument_ptr_2.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + auto invoker_ptr_2 = reduce_2.MakeInvokerPointer(); + + float avg_time_1 = invoker_ptr_1->Run(argument_ptr_1.get(), StreamConfig{nullptr, time_kernel}); + float avg_time_2 = invoker_ptr_2->Run(argument_ptr_2.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + + invariant_total_length * sizeof(InOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / (avg_time_1 + avg_time_2); + + std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, " + << reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verify) + { + out_dev.FromDevice(out.mData.data()); + pass = pass && ck::utils::check_err(out, out_ref); + }; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/12_reduce/reduce_example_common.hpp b/3rdparty/composable_kernel/example/12_reduce/reduce_example_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..05f0a0edb25ba520d475522bd2963bc050ccc010 --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/reduce_example_common.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +template +static inline std::array +get_invariant_dims(const std::array& reduceDims) +{ + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + std::array invariantDims; + + // collect invariant dimensions + int dim = 0; + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + invariantDims[dim] = i; + dim++; + }; + + return invariantDims; +}; + +template +struct ReduceShape +{ + static constexpr ck::index_t Rank_ = Rank; + static constexpr ck::index_t NumReduceDim_ = NumReduceDim; +}; + +using reduce_shape_instances = std::tuple, + ReduceShape<3, 2>, + ReduceShape<4, 1>, + ReduceShape<4, 2>, + ReduceShape<4, 3>, + ReduceShape<5, 1>, + ReduceShape<5, 2>, + ReduceShape<5, 3>, + ReduceShape<5, 4>>; diff --git a/3rdparty/composable_kernel/example/12_reduce/reduce_multiblock_atomic_add.cpp b/3rdparty/composable_kernel/example/12_reduce/reduce_multiblock_atomic_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4d63a3add8b0d8016b11b8834a1385dcca5beb5 --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/reduce_multiblock_atomic_add.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/utility/reduction_enums.hpp" +#include "reduce_multiblock_atomic_add_impl.hpp" +#include "reduce_example_common.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {16, 64, 32, 960}; + std::vector reduceDims = {0, 1, 2}; + std::vector scales = {1.0f, 0.0f}; + + bool do_verification = true; + int data_type = 1; + int init_method = 2; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1: data type (0: fp32, 1: fp64)" << std::endl; + std::cout << "Arg2 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg3 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:R:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 3 > argc) + { + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + }; + + data_type = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +template +bool reduce_multiblock_atomic_add_test(bool do_verification, + int init_method, + bool time_kernel, + const std::vector& inLengths, + const std::vector& reduceDims, + float alpha, + float beta) +{ + bool matched = false; + int result = 0; + + const auto tuple_object = reduce_shape_instances{}; + + static_for<0, std::tuple_size::value, 1>{}([&](auto i) { + if(matched) + return; + + using ShapeType = remove_cvref_t(tuple_object))>; + + if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size()) + return; + + std::array a_reduceDims; + + ck::ranges::copy(reduceDims, a_reduceDims.begin()); + + result = reduce_multiblock_atomic_add_impl( + do_verification, init_method, time_kernel, inLengths, a_reduceDims, alpha, beta); + + matched = true; + }); + + return (result == 0) ? true : false; +}; + +constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AVG; +constexpr bool PropagateNan = true; + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + SimpleAppArgs arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + pass = reduce_multiblock_atomic_add_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + else if(arg.data_type == 1) + { + pass = reduce_multiblock_atomic_add_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + } + else + { + // for testing float + pass = pass && reduce_multiblock_atomic_add_test( + true, 2, false, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing double + pass = pass && reduce_multiblock_atomic_add_test( + true, 2, false, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing 3D input + pass = pass && reduce_multiblock_atomic_add_test( + true, 2, false, {16, 64, 960}, {0, 1}, 1.0f, 0.0f); + + // for testing 5D input + pass = pass && reduce_multiblock_atomic_add_test( + true, 2, false, {16, 64, 32, 2, 960}, {0, 1, 2, 3}, 1.0f, 0.0f); + }; + + return (pass ? 0 : 1); +}; diff --git a/3rdparty/composable_kernel/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp b/3rdparty/composable_kernel/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..94867aee41ec07520c81e109e8a5c7c298c0d029 --- /dev/null +++ b/3rdparty/composable_kernel/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_reduction.hpp" + +#include "reduce_example_common.hpp" + +template +int reduce_multiblock_atomic_add_impl(bool do_verification, + int init_method, + bool time_kernel, + const std::vector& inLengths, + const std::array& reduceDims, + float alpha, + float beta) + +{ + using namespace ck; + using namespace ck::tensor_operation::device; + + constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim; + + constexpr bool op_support_atomic_add = + (ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG); + + constexpr bool invalid_reduce_1 = !op_support_atomic_add; + constexpr bool invalid_reduce_2 = + !(std::is_same::value || std::is_same::value); + + constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2); + + if(invalid_reduce) + { + std::cerr << "The reduction setting is invalid, exiting!" << std::endl; + return (-1); + }; + + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + using DeviceReduceInstance = + ck::tensor_operation::device::DeviceReduceMultiBlock; + + Tensor in(inLengths); + + std::vector outLengths; + + auto invariantDims = get_invariant_dims(reduceDims); + + if(invariantDims.empty()) + outLengths.push_back(1); + else + for(auto dim : invariantDims) + outLengths.push_back(inLengths[dim]); + + Tensor out_ref(outLengths); + Tensor out(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verification) + { + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InOutDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpaceSize()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); + + if(do_verification) + { + ReductionHost + hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + nullptr, + in_elementwise_op, + acc_elementwise_op); + }; + + std::array arrInLengths; + std::array arrInStrides; + std::array arrOutLengths; + std::array arrOutStrides; + + ck::ranges::copy(inLengths, arrInLengths.begin()); + ck::ranges::copy(inStrides, arrInStrides.begin()); + ck::ranges::copy(outLengths, arrOutLengths.begin()); + ck::ranges::copy(outStrides, arrOutStrides.begin()); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths, + arrInStrides, + arrOutLengths, + arrOutStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + acc_elementwise_op); + + if(!reduce.IsSupportedArgument(argument_ptr.get())) + { + std::cerr + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + + return (-2); + }; + + std::string reduce_name = reduce.GetTypeString(); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + + invariant_total_length * sizeof(InOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name + << std::endl; + + bool pass = true; + + if(do_verification) + { + out_dev.FromDevice(out.mData.data()); + pass = pass && ck::utils::check_err(out, out_ref); + }; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/13_pool2d_fwd/CMakeLists.txt b/3rdparty/composable_kernel/example/13_pool2d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..db09c03321e5ac0f3408dcffbc940b6c9c0879c5 --- /dev/null +++ b/3rdparty/composable_kernel/example/13_pool2d_fwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) + diff --git a/3rdparty/composable_kernel/example/13_pool2d_fwd/README.md b/3rdparty/composable_kernel/example/13_pool2d_fwd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9b017734e92b8fa01b193b9f2c25b0a5ebf4da43 --- /dev/null +++ b/3rdparty/composable_kernel/example/13_pool2d_fwd/README.md @@ -0,0 +1,41 @@ +# Instructions for ```example_pool2d_fwd``` Examples + +## Run ```example_pool2d_fwd_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_pool2d_fwd_fp16 1 1 1 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192} +launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.397436 ms, 1.44252 TFlops, 783.713 GB/s +``` + +## Run ```example_pool2d_fwd_fp32``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_pool2d_fwd_fp32 1 1 1 +``` + + +Result +``` +./bin/example_pool2d_fwd_fp32 1 1 1 +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192} +launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.01823 ms, 0.563045 TFlops, 611.8 GB/s +``` diff --git a/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b83cb6a96f0b10f0aeff5e57d386f2ac1c3fa4e4 --- /dev/null +++ b/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +static void pool_host_verify(const Tensor& in, + Tensor& out, + Tensor& out_indices, + const std::array& window_spatial_lengths, + const std::array& window_strides, + const std::array& in_left_pads, + const std::array& /*in_right_pads*/) +{ + const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + using ReduceOperation = typename ck::reduce_binary_operator::opType; + + auto elementwise_ops = + ck::reduce_unary_operator::GetElementwiseOperator(reduceLength); + + auto in_elementwise_op = std::get<0>(elementwise_ops); + auto acc_elementwise_op = std::get<1>(elementwise_ops); + + if constexpr(!OutputIndex) + { + using Accumulation = + ck::detail::AccumulateWithNanCheck; + + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < static_cast(in.mDesc.GetLengths()[2]) && + wi >= 0 && wi < static_cast(in.mDesc.GetLengths()[3])) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + + out(n, c, ho, wo) = accuVal; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + IndexDataType currIndex = y * window_spatial_lengths[1] + x; + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + + out(n, c, ho, wo) = accuVal; + out_indices(n, c, ho, wo) = accuIndex; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + }; +} + +template +bool pool_test(bool do_verification, + int init_method, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Y, + ck::index_t X, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< + InDataType, // InDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + ReduceOpId, + OutputIndex, + 64, // BlockSize + 64, // ReduceMThreadClusterSize + 1, // ReduceKThreadClusterSize + 4, // ReduceMThreadSliceSize + 1, // ReduceKThreadSliceSize + 4>; // InSrcOutDstVectorSize + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + const std::array window_spatial_lengths{{Y, X}}; + const std::array window_strides{{window_stride_h, window_stride_w}}; + const std::array input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::array input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + using namespace ck::literals; + + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_host( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_device( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); break; + case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + auto pool = DevicePoolFwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = pool.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + N, + C, + std::array{{Hi, Wi}}, + std::array{{Y, X}}, + std::array{{Ho, Wo}}, + window_strides, + input_left_pads, + input_right_pads); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(OutDataType) * (N * C * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + pool_host_verify(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); + + pass = pass && + ck::utils::check_err(out_indices_n_c_ho_wo_device, out_indices_n_c_ho_wo_host); + }; + } + + return (pass); +}; diff --git a/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..659f3251dcf9f8d7661cf44169458937f0d15088 --- /dev/null +++ b/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "pool2d_fwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using OutLayout = ck::tensor_layout::convolution::NHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +int main(int argc, char* argv[]) +{ + bool do_verification; + int init_method; + bool time_kernel; + + // Pool shape + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + do_verification = true; + init_method = 1; + time_kernel = true; + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + } + else if(argc == 16) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + in_left_pad_h = std::stoi(argv[12]); + in_left_pad_w = std::stoi(argv[13]); + in_right_pad_h = std::stoi(argv[14]); + in_right_pad_w = std::stoi(argv[15]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + bool pass = pool_test(do_verification, + init_method, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp b/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f47c7ff15147327603b9eeba815e44465313762c --- /dev/null +++ b/3rdparty/composable_kernel/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "pool2d_fwd_common.hpp" + +using InDataType = float; +using OutDataType = float; +using AccDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using OutLayout = ck::tensor_layout::convolution::NHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +int main(int argc, char* argv[]) +{ + bool do_verification; + int init_method; + bool time_kernel; + + // Pool shape + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + do_verification = true; + init_method = 1; + time_kernel = true; + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + } + else if(argc == 16) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + in_left_pad_h = std::stoi(argv[12]); + in_left_pad_w = std::stoi(argv[13]); + in_right_pad_h = std::stoi(argv[14]); + in_right_pad_w = std::stoi(argv[15]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + bool pass = pool_test(do_verification, + init_method, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/14_gemm_quantization/CMakeLists.txt b/3rdparty/composable_kernel/example/14_gemm_quantization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca09c48c10cd404502dcdafc75cc013af90e22ed --- /dev/null +++ b/3rdparty/composable_kernel/example/14_gemm_quantization/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) \ No newline at end of file diff --git a/3rdparty/composable_kernel/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp b/3rdparty/composable_kernel/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d5f4e6f62c3eb8e58285e81446a5e1332c37c533 --- /dev/null +++ b/3rdparty/composable_kernel/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using CDEElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using BiasDataType = I32; +using DsDataType = ck::Tuple; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using BiasLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + PassThrough, // AElementwiseOperation, + PassThrough, // BElementwiseOperation, + CDEElementOp, // CDEElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // NumGemmKPrefetchStage, + 256, // BlockSize, + 256, // MPerBlock, + 128, // NPerBlock, + 64, // KPerBlock, + 16, // AK1, + 16, // BK1, + 32, // MPerXDL, + 32, // NPerXDL, + 4, // MXdlPerWave, + 2, // NXdlPerWave, + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideBias = 0; + ck::index_t StrideE = 1024; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1_uz})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1_uz, stride})); + } + }; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor e_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "bias_n: " << bias_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + bias_n.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideBias}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor{M, N}); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), bias_n(n)); + } + } + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp b/3rdparty/composable_kernel/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2371737382447d666d55633fbf3ac976c3017cf7 --- /dev/null +++ b/3rdparty/composable_kernel/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + PassThrough, // AElementwiseOperation, + PassThrough, // BElementwiseOperation, + CDEElementOp, // CDEElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // NumGemmKPrefetchStage, + 256, // BlockSize, + 256, // MPerBlock, + 128, // NPerBlock, + 64, // KPerBlock, + 16, // AK1, + 16, // BK1, + 32, // MPerXDL, + 32, // NPerXDL, + 4, // MXdlPerWave, + 2, // NXdlPerWave, + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1_uz})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1_uz, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/CMakeLists.txt b/3rdparty/composable_kernel/example/15_grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..67f6160873584e0fe31e6e60f675c695b33c8460 --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/CMakeLists.txt @@ -0,0 +1,17 @@ +add_custom_target(example_grouped_gemm_xdl) + +add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) +add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) +add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) +add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) + +add_dependencies(example_grouped_gemm_xdl + example_grouped_gemm_xdl_fp32 + example_grouped_gemm_xdl_fp16 + example_grouped_gemm_xdl_bfp16 + example_grouped_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) +endif() diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/README.md b/3rdparty/composable_kernel/example/15_grouped_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c83b23e08cc7b923ce22316e0c882886d7437ef3 --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/README.md @@ -0,0 +1,25 @@ +# Instructions for ```example_grouped_gemm_xdl``` + +## Run ```example_grouped_gemm_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./bin/example_grouped_gemm_xdl_fp16 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1} +gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1} +gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1} +gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1} +group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128} +group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256} +group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384} +group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512} +launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> +``` diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05d572a1f532446bdc0ec43e59b1eeadfe4bd054 --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f78dafa8977b1b57592500138dd8f034b294efd --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd93bb5f87d49a96da8973c3f862b8b48bdeb180 --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..faf41bbf0bba8e06fee9982a2f158706401ca45d --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using DsDataType = ck::Tuple<>; +using EDataType = ck::int4_t; + +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelEDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off + < ALayout, //ALayout + BLayout, //BLayout + DsLayout, //DsLayout + ELayout, //ELayout + KernelADataType, //ADataType + KernelBDataType, //BDataType + AccDataType, //AccDataType + CShuffleDataType, //CShuffleDataType + DsDataType, //DsDataType + KernelEDataType, //EDataType + AElementOp, //AElementwiseOperation + BElementOp, //BElementwiseOperation + CDEElementOp, //CDEElementwiseOperation + GemmDefault, //GEMMSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl + 16>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +#define BUILD_INT4_EXAMPLE +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7cb09778c521c9dfaed38c39ee5cb7fcff7bbc26 --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; +using DsDataType = ck::Tuple<>; +using EDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/15_grouped_gemm/run_grouped_gemm_example.inc b/3rdparty/composable_kernel/example/15_grouped_gemm/run_grouped_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..324e1772805ac2f1b9b84872d3b8baec3c39c6b5 --- /dev/null +++ b/3rdparty/composable_kernel/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -0,0 +1,265 @@ +#pragma once + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(KernelADataType)); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType)); + static_assert(sizeof(EDataType) == sizeof(KernelEDataType)); +#endif + int group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_a, p_b; + std::vector p_c; + + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + int M = problem_size.Ms[i]; + int N = problem_size.Ns[i]; + int K = problem_size.Ks[i]; + + int stride_A = problem_size.stride_As[i]; + int stride_B = problem_size.stride_Bs[i]; + int stride_C = problem_size.stride_Cs[i]; + + gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, {}}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; +#ifdef BUILD_INT4_EXAMPLE + std::vector> c_device_tensors; +#else + std::vector> c_device_tensors; +#endif + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); +#ifdef BUILD_INT4_EXAMPLE + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); +#else + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); +#endif + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize())); + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize())); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor a_converted(a_tensors[i]); + const Tensor b_converted(b_tensors[i]); + + a_tensors_device[i]->ToDevice(a_converted.mData.data()); + b_tensors_device[i]->ToDevice(b_converted.mData.data()); +#else + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); +#endif + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor c_device_result_converted(c_device_tensors[i]); + pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]); + +#else + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); +#endif + } + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_grouped_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + return run_grouped_gemm(problem_size, config); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..226656a736637beea41758bf59dd1bd27ed8189f --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -0,0 +1,40 @@ +add_custom_target(example_gemm_reduce_xdl) +add_custom_target(example_gemm_reduce_xdl_max) +add_custom_target(example_gemm_reduce_xdl_mean_meansquare) +add_custom_target(example_gemm_add_add_mean_meansquare_xdl) + +add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) +add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) +add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) +add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + +add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + +add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) +add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) +add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) +add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + +add_dependencies(example_gemm_reduce_xdl_max + example_gemm_max_xdl_bf16 + example_gemm_max_xdl_fp16 + example_gemm_max_xdl_fp32 + example_gemm_max_xdl_int8) + +add_dependencies(example_gemm_reduce_xdl_mean_meansquare + example_gemm_mean_meansquare_xdl_fp16 + example_gemm_mean_meansquare_xdl_fp32 + example_gemm_mean_meansquare_xdl_bf16 + example_gemm_add_addsquare_xdl_int8) + +add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + +add_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) +endif() diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eb3832a668ff9829ee76e87e6ea8988b40e6980b --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +// DataType +using ADataType = F16; +using BDataType = F16; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; +using ReduceAccDataType = F32; +using R0DataType = F32; +using R1DataType = F32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using D1Layout = Row; +using ELayout = D1Layout; + +// Elementwise op +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using R0ThreadReduceOp = ck::reduce::Add; +using R1ThreadReduceOp = ck::reduce::Add; +using RsThreadReduceOp = ck::Tuple; + +static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle +//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| +//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | + < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template +void DumpPerf(float ave_time, int M, int N, int K) +{ + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N + sizeof(R0DataType) * M + + sizeof(R1DataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec + << " GB/s, " << std::endl; +} + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + +int main() +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 1024; + ck::index_t StrideE = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor d0_n(f_host_tensor_descriptor1d(N, 1)); + Tensor d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{})); + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor r0_m(f_host_tensor_descriptor1d(M, 1)); + Tensor r1_m(f_host_tensor_descriptor1d(M, 1)); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + d0_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize()); + DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize()); + DeviceMem r1_device_buf(sizeof(R1DataType) * r1_m.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_device_buf.ToDevice(d0_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + auto qs_element_op = QsElementOp{}; + auto rs_element_op = RsElementOp{N, N}; + + // Prepare GEMM, mean, mean_square + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + {r0_device_buf.GetDeviceBuffer(), r1_device_buf.GetDeviceBuffer()}, + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + // init reducetion buffer to 0 + r0_device_buf.SetZero(); + r1_device_buf.SetZero(); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool do_verification = true; + bool pass = true; + + if(do_verification) + { + auto I0 = ck::Number<0>{}; + auto I1 = ck::Number<1>{}; + + Tensor e_m_n_host(e_m_n.mDesc); + Tensor r0_m_host(r0_m.mDesc); + Tensor r1_m_host(r1_m.mDesc); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + auto reduce0_op = R0ThreadReduceOp{}; + auto reduce1_op = R1ThreadReduceOp{}; + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType square_e_val; + + auto e_val = ck::type_convert(e_m_n_host(m, n)); + auto d0_val = ck::type_convert(d0_n(n)); + auto d1_val = ck::type_convert(d1_m_n(m, n)); + cde_element_op(e_val, e_val, d0_val, d1_val); + e_m_n_host(m, n) = ck::type_convert(e_val); + + auto e_val_reduce = ck::type_convert(e_val); + qs_element_op[I1](square_e_val, e_val_reduce); + + reduce0_op(reduce0_acc, e_val_reduce); + reduce1_op(reduce1_acc, square_e_val); + } + + rs_element_op[I0](reduce0_acc, reduce0_acc); + rs_element_op[I1](reduce1_acc, reduce1_acc); + r0_m_host(m) = ck::type_convert(reduce0_acc); + r1_m_host(m) = ck::type_convert(reduce1_acc); + } + + e_device_buf.FromDevice(e_m_n.mData.data()); + r0_device_buf.FromDevice(r0_m.mData.data()); + r1_device_buf.FromDevice(r1_m.mData.data()); + + pass = ck::utils::check_err(e_m_n, e_m_n_host, "Error: Incorrect results c", 1e-2, 1e-2); + pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); + pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); + } + + bool time_kernel = true; + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + DumpPerf( + ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1248002f751abce96072ee64283ca00022a8713 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// DataType +using ADataType = INT8; +using BDataType = INT8; +using GemmAccDataType = INT32; +using CShuffleDataType = INT32; +using DsDataType = ck::Tuple<>; +using EDataType = INT8; +using ReduceAccDataType = INT32; +using R0DataType = INT32; +using R1DataType = INT32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using Square = ck::tensor_operation::element_wise::UnarySquare; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using R0ThreadReduceOp = ck::reduce::Add; +using R1ThreadReduceOp = ck::reduce::Add; +using RsThreadReduceOp = ck::Tuple; + +static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using namespace ck::literals; + +template +bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideE, + bool do_verification, + int init_method, + bool time_kernel) +{ + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor r0_m(f_host_tensor_descriptor1d(M, 1)); + Tensor r1_m(f_host_tensor_descriptor1d(M, 1)); + + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize()); + DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize()); + DeviceMem r1_device_buf(sizeof(R1DataType) * r1_m.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + auto qs_element_op = QsElementOp{}; + auto rs_element_op = RsElementOp{}; + + // Prepare GEMM, add, add_square + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + {r0_device_buf.GetDeviceBuffer(), r1_device_buf.GetDeviceBuffer()}, + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + // init reducetion buffer to 0 + r0_device_buf.SetZero(); + r1_device_buf.SetZero(); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool pass = true; + + if(do_verification) + { + auto I0 = ck::Number<0>{}; + auto I1 = ck::Number<1>{}; + + Tensor e_m_n_host(e_m_n.mDesc); + Tensor r0_m_host(r0_m.mDesc); + Tensor r1_m_host(r1_m.mDesc); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + auto reduce0_op = RsThreadReduceOp{}[I0]; + auto reduce1_op = RsThreadReduceOp{}[I1]; + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.template GetIdentityValue(); + auto reduce1_acc = reduce1_op.template GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType square_e_val; + auto e_val = ck::type_convert(e_m_n_host(m, n)); + qs_element_op[I1](square_e_val, e_val); + + reduce0_op(reduce0_acc, e_val); + reduce1_op(reduce1_acc, square_e_val); + } + + r0_m_host(m) = ck::type_convert(reduce0_acc); + r1_m_host(m) = ck::type_convert(reduce1_acc); + } + e_device_buf.FromDevice(e_m_n.mData.data()); + + Tensor e_m_n_host_converted(e_m_n_host); + + pass = ck::utils::check_err( + e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2); + + r0_device_buf.FromDevice(r0_m.mData.data()); + r1_device_buf.FromDevice(r1_m.mData.data()); + + pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); + pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); + + if(pass) + { + std::cout << "Success!" << std::endl; + } + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + std::size_t flop = 2_uz * M * N * K + 3_uz * M * N; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(R0DataType) * M + + sizeof(R1DataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec + << " GB/s, " << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1152; + ck::index_t K = 512; + + ck::index_t StrideA = 512; + ck::index_t StrideB = 512; + ck::index_t StrideE = 1152; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return !run_gemm_reduce_add_addsquare_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2feffeb8953df030ad2e30cb942bec529b2cac1 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// DataType +using ADataType = BF16; +using BDataType = BF16; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; +using ReduceAccDataType = F32; +using R0DataType = F32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using RsThreadReduceOp = ck::Tuple; +using RsGlobalReduceOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 8, // ABlockTransfer SrcScalarPerVector + 8, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 8, // BBlockTransfer SrcScalarPerVector + 8, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1152; + ck::index_t K = 256; + + ck::index_t StrideA = 256; + ck::index_t StrideB = 256; + ck::index_t StrideE = 1152; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return run_gemm_reduce_max_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..363390add3e99c96ef3891bec995cda5f8c3e24b --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// DataType +using ADataType = F16; +using BDataType = F16; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; +using ReduceAccDataType = F32; +using R0DataType = F32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using RsThreadReduceOp = ck::Tuple; +using RsGlobalReduceOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 8, // ABlockTransfer SrcScalarPerVector + 8, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 8, // BBlockTransfer SrcScalarPerVector + 8, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return run_gemm_reduce_max_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..de6b7eb480b30d4347018da94ec518ff29fdbeec --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// DataType +using ADataType = F32; +using BDataType = F32; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ReduceAccDataType = F32; +using R0DataType = F32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using RsThreadReduceOp = ck::Tuple; +using RsGlobalReduceOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 4, // ABlockTransfer SrcScalarPerVector + 4, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 4, // BBlockTransfer SrcScalarPerVector + 4, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: Measure kernel execution time (1=ON, 0=Off)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + exit(0); + } + + return run_gemm_reduce_max_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9666fc6622cdf3ab7cbab7071ebffd69d18b7668 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = INT4; +using ADataKernelType = INT8; +using BDataType = INT4; +using BDataKernelType = INT8; +using GemmAccDataType = INT32; +using CShuffleDataType = INT32; +using DsDataType = ck::Tuple<>; +using EDataType = INT4; +using EDataKernelType = INT8; +using ReduceAccDataType = INT32; +using R0DataType = INT32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using RsThreadReduceOp = ck::Tuple; +using RsGlobalReduceOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1152; + ck::index_t K = 256; + + ck::index_t StrideA = 256; + ck::index_t StrideB = 256; + ck::index_t StrideE = 1152; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return run_gemm_reduce_max_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..00e0b767a45dd979868b81635a79600d4695c779 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = INT8; +using BDataType = INT8; +using GemmAccDataType = INT32; +using CShuffleDataType = INT32; +using DsDataType = ck::Tuple<>; +using EDataType = INT8; +using ReduceAccDataType = INT32; +using R0DataType = INT32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using RsThreadReduceOp = ck::Tuple; +using RsGlobalReduceOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1152; + ck::index_t K = 512; + + ck::index_t StrideA = 512; + ck::index_t StrideB = 512; + ck::index_t StrideE = 1152; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return run_gemm_reduce_max_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..652c0e6ea6d2bbe3321de2deefb2c186d7252023 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// DataType +using ADataType = BF16; +using BDataType = BF16; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; +using ReduceAccDataType = F32; +using R0DataType = F32; +using R1DataType = F32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using Square = ck::tensor_operation::element_wise::UnarySquare; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using R0ThreadReduceOp = ck::reduce::Add; +using R1ThreadReduceOp = ck::reduce::Add; +using RsThreadReduceOp = ck::Tuple; + +static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 8, // ABlockTransfer SrcScalarPerVector + 8, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 8, // BBlockTransfer SrcScalarPerVector + 8, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1152; + ck::index_t K = 192; + + ck::index_t StrideA = 192; + ck::index_t StrideB = 192; + ck::index_t StrideE = 1152; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return !run_gemm_reduce_mean_meansquare_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7eee24fed83988698e0e6d46f830f3a9dfdb0257 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// DataType +using ADataType = F16; +using BDataType = F16; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; +using ReduceAccDataType = F32; +using R0DataType = F32; +using R1DataType = F32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using Square = ck::tensor_operation::element_wise::UnarySquare; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using R0ThreadReduceOp = ck::reduce::Add; +using R1ThreadReduceOp = ck::reduce::Add; +using RsThreadReduceOp = ck::Tuple; + +static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 8, // ABlockTransfer SrcScalarPerVector + 8, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 8, // BBlockTransfer SrcScalarPerVector + 8, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return !run_gemm_reduce_mean_meansquare_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c250b996928dd406bde9253ccbf40e6cd9d77347 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_reduce_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// DataType +using ADataType = F32; +using BDataType = F32; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ReduceAccDataType = F32; +using R0DataType = F32; +using R1DataType = F32; +using RsDataType = ck::Tuple; + +// Layout +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +// Elementwise op +using Square = ck::tensor_operation::element_wise::UnarySquare; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using R0ThreadReduceOp = ck::reduce::Add; +using R1ThreadReduceOp = ck::reduce::Add; +using RsThreadReduceOp = ck::Tuple; + +static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 4, // ABlockTransfer SrcScalarPerVector + 4, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 4, // BBlockTransfer SrcScalarPerVector + 4, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock + 1>; // RThread DstScalarPerVector _MPerBlock +// clang-format on +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" + << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return !run_gemm_reduce_mean_meansquare_xdl( + M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62992de59765d3e05d935ac69894c4550a7ef472 --- /dev/null +++ b/3rdparty/composable_kernel/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -0,0 +1,491 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/io.hpp" +#include "ck/stream_config.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using F64 = double; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using INT4 = ck::int4_t; +#endif +using INT8 = std::int8_t; +using INT32 = std::int32_t; + +template +void DumpGemmReduceMaxPerf(float ave_time, int M, int N, int K) +{ + using namespace ck::literals; + + std::size_t flop = 2_uz * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(R0DataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec + << " GB/s, " << std::endl; +} + +template +void DumpGemmReduceMeanSquareMeanPerf(float ave_time, int M, int N, int K) +{ + using namespace ck::literals; + + std::size_t flop = 2_uz * M * N * K + M * (3_uz * N + 2_uz); + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(R0DataType) * M + + sizeof(R1DataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec + << " GB/s, " << std::endl; +} + +template +auto run_gemm_reduce_max_xdl(ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideE, + bool do_verification, + int init_method, + bool time_kernel) +{ +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(ADataKernelType)); + static_assert(sizeof(BDataType) == sizeof(BDataKernelType)); + static_assert(sizeof(EDataType) == sizeof(EDataKernelType)); +#endif + using namespace ck::literals; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor r0_m(f_host_tensor_descriptor1d(M, 1)); + + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + } + + DeviceMem a_device_buf(sizeof(ADataKernelType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataKernelType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataKernelType) * e_m_n.mDesc.GetElementSpaceSize()); + DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize()); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + Tensor a_m_k_converted = a_m_k.template CopyAsType(); + Tensor b_k_n_converted = b_k_n.template CopyAsType(); + + a_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_device_buf.ToDevice(b_k_n_converted.mData.data()); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + auto qs_element_op = QsElementOp{}; + auto rs_element_op = RsElementOp{}; + + // Prepare GEMM, max + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + {r0_device_buf.GetDeviceBuffer()}, + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + // [CAUTION]: launch_and_time_kernel will not initialize D. + // If we evaluate kernel multiple time but without initialize D. Verification will fail + r0_device_buf.SetValue(ck::NumericLimits::Lowest()); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool pass = true; + + if(do_verification) + { + auto I0 = ck::Number<0>{}; + + Tensor e_m_n_host(e_m_n.mDesc); + Tensor r0_m_host(r0_m.mDesc); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + auto reduce0_op = RsThreadReduceOp{}[I0]; + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.template GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto e_val = e_m_n_host(m, n); + reduce0_op(reduce0_acc, e_val); + }; + + r0_m_host(m) = ck::type_convert(reduce0_acc); + } + + e_device_buf.FromDevice(e_m_n.mData.data()); + Tensor e_m_n_host_converted(e_m_n_host); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + Tensor e_m_n_device_converted(e_m_n); + pass = ck::utils::check_err(e_m_n_device_converted, + e_m_n_host_converted, + "Error: Incorrect results c", + 1e-2, + 1e-2); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + pass = ck::utils::check_err( + e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2); + } + + r0_device_buf.FromDevice(r0_m.mData.data()); + pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); + + if(pass) + { + std::cout << "Success!" << std::endl; + } + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + DumpGemmReduceMaxPerf(ave_time, M, N, K); + } + + return pass ? 0 : 1; +} + +template +bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideE, + bool do_verification, + int init_method, + bool time_kernel) +{ +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(ADataKernelType)); + static_assert(sizeof(BDataType) == sizeof(BDataKernelType)); + static_assert(sizeof(EDataType) == sizeof(EDataKernelType)); +#endif + using namespace ck::literals; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor r0_m(f_host_tensor_descriptor1d(M, 1)); + Tensor r1_m(f_host_tensor_descriptor1d(M, 1)); + + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + } + + DeviceMem a_device_buf(sizeof(ADataKernelType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataKernelType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataKernelType) * e_m_n.mDesc.GetElementSpaceSize()); + DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize()); + DeviceMem r1_device_buf(sizeof(R1DataType) * r1_m.mDesc.GetElementSpaceSize()); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + Tensor a_m_k_converted = a_m_k.template CopyAsType(); + Tensor b_k_n_converted = b_k_n.template CopyAsType(); + + a_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_device_buf.ToDevice(b_k_n_converted.mData.data()); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + auto qs_element_op = QsElementOp{}; + auto rs_element_op = RsElementOp{N, N}; + + // Prepare GEMM, mean, mean_square + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + {r0_device_buf.GetDeviceBuffer(), r1_device_buf.GetDeviceBuffer()}, + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + // init reducetion buffer to 0 + r0_device_buf.SetZero(); + r1_device_buf.SetZero(); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool pass = true; + + if(do_verification) + { + auto I0 = ck::Number<0>{}; + auto I1 = ck::Number<1>{}; + + Tensor e_m_n_host(e_m_n.mDesc); + Tensor r0_m_host(r0_m.mDesc); + Tensor r1_m_host(r1_m.mDesc); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + auto reduce0_op = RsThreadReduceOp{}[I0]; + auto reduce1_op = RsThreadReduceOp{}[I1]; + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.template GetIdentityValue(); + auto reduce1_acc = reduce1_op.template GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType square_e_val; + auto e_val = ck::type_convert(e_m_n_host(m, n)); + qs_element_op[I1](square_e_val, e_val); + + reduce0_op(reduce0_acc, e_val); + reduce1_op(reduce1_acc, square_e_val); + } + + rs_element_op[I0](reduce0_acc, reduce0_acc); + rs_element_op[I1](reduce1_acc, reduce1_acc); + r0_m_host(m) = ck::type_convert(reduce0_acc); + r1_m_host(m) = ck::type_convert(reduce1_acc); + } + e_device_buf.FromDevice(e_m_n.mData.data()); + Tensor e_m_n_host_converted(e_m_n_host); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + Tensor e_m_n_device_converted(e_m_n); + pass = ck::utils::check_err(e_m_n_device_converted, + e_m_n_host_converted, + "Error: Incorrect results c", + 1e-2, + 1e-2); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + pass = ck::utils::check_err( + e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2); + } + + r0_device_buf.FromDevice(r0_m.mData.data()); + r1_device_buf.FromDevice(r1_m.mData.data()); + + pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); + pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); + + if(pass) + { + std::cout << "Success!" << std::endl; + } + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + DumpGemmReduceMeanSquareMeanPerf( + ave_time, M, N, K); + } + + return pass; +} diff --git a/3rdparty/composable_kernel/example/17_convnd_bwd_data/CMakeLists.txt b/3rdparty/composable_kernel/example/17_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa4e65d965f16c7049a21dba93e881fa04cd5319 --- /dev/null +++ b/3rdparty/composable_kernel/example/17_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,5 @@ +add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) +target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + +add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) +target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) diff --git a/3rdparty/composable_kernel/example/17_convnd_bwd_data/README.md b/3rdparty/composable_kernel/example/17_convnd_bwd_data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b5c8281ed8aea3dd649d2715289f09f2d9e4c60a --- /dev/null +++ b/3rdparty/composable_kernel/example/17_convnd_bwd_data/README.md @@ -0,0 +1,47 @@ +# Instructions for ```example_convnd_bwd_data_xdl``` + +## Run ```example_example_convnd_bwd_data_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: num_dim_spatial(1|2|3) +#arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, S[z,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPy,] [RightPy,] RightPx +./bin/example_convnd_bwd_data_xdl 0 1 5 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 128, 71, 71}, strides {645248, 1, 9088, 128} +wei_k_c_y_x: dim 4, lengths {256, 128, 3, 3}, strides {1152, 1, 384, 128} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{128, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{32, 128, 8} +arg.c_grid_desc_m_n_container_{ 175232, 128} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 1.40031 ms, 69.8734 TFlops, 179.037 GB/s +``` diff --git a/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..26fa9e9821fc050c9ab69fe1a170b4c72a98dcd3 --- /dev/null +++ b/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +int run_conv_bwd_data(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out(out_g_n_k_wos_desc); + + std::cout << "in: " << in_host.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_1{1}); + wei.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + // do GEMM + auto conv = DeviceConvNdBwdDataInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.N_, + conv_param.K_, + conv_param.C_, + conv_param.input_spatial_lengths_, + conv_param.filter_spatial_lengths_, + conv_param.GetOutputSpatialLengths(), + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + std::cout << "Not support,please check parameters or device"; + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_device.mData.data()); + + return ck::utils::check_err(in_device, in_host) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp b/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f0896e977144867421b857afd0a58a239716712b --- /dev/null +++ b/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_bwd_data_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +template +// clang-format off +using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Dl< +// ######| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +int main(int argc, char* argv[]) +{ + namespace ctc = ck::tensor_layout::convolution; + + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(conv_param.num_dim_spatial_ == 1) + { + using InLayout = ctc::GNWC; + using WeiLayout = ctc::GKXC; + using OutLayout = ctc::GNWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<1, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<1>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 2) + { + using InLayout = ctc::GNHWC; + using WeiLayout = ctc::GKYXC; + using OutLayout = ctc::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<2, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<2>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 3) + { + using InLayout = ctc::GNDHWC; + using WeiLayout = ctc::GKZYXC; + using OutLayout = ctc::GNDHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<3, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<3>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp b/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4f2c1f02bb9010ab3f5d97166819732e911b68a --- /dev/null +++ b/3rdparty/composable_kernel/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_bwd_data_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +template +using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl< + NDimSpatial, // NDimSpatial + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, + 1>; // GemmCThreadTransferDstScalarPerVector + +int main(int argc, char* argv[]) +{ + namespace ctc = ck::tensor_layout::convolution; + + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(conv_param.num_dim_spatial_ == 1) + { + using InLayout = ctc::GNWC; + using WeiLayout = ctc::GKXC; + using OutLayout = ctc::GNWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<1, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<1>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 2) + { + using InLayout = ctc::GNHWC; + using WeiLayout = ctc::GKYXC; + using OutLayout = ctc::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<2, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<2>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 3) + { + using InLayout = ctc::GNDHWC; + using WeiLayout = ctc::GKZYXC; + using OutLayout = ctc::GNDHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<3, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<3>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/18_batched_gemm_reduce/CMakeLists.txt b/3rdparty/composable_kernel/example/18_batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..99fc0043d2803494522b446cabf4a3f98ca12a5a --- /dev/null +++ b/3rdparty/composable_kernel/example/18_batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) + diff --git a/3rdparty/composable_kernel/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/3rdparty/composable_kernel/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2e3602a7bb22e3f31c66bb6045171229584fa14 --- /dev/null +++ b/3rdparty/composable_kernel/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using ReduceAccDataType = F32; +using ReduceDataType = F32; +using ReducePtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceOp0 = ck::reduce::Add; +using ReduceOp1 = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceGlobalMemOps = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 2048; + ck::index_t N = 1920; + ck::index_t K = 2048; + + ck::index_t StrideA = 2048; + ck::index_t StrideB = 2048; + ck::index_t StrideC = 1920; + + ck::index_t BatchCount = 4; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + + BatchCount = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n"); + exit(0); + } + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride}); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result({BatchCount, M}); + Tensor d1_g_m_host_result({BatchCount, M}); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result({BatchCount, M}); + Tensor d1_g_m_device_result({BatchCount, M}); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + d0_g_m_device_result.mDesc.GetElementSpaceSize()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + d1_g_m_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&passthrough, &passthrough}; + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + // do GEMM + auto batched_gemm = DeviceBatchedGemmReduceInstance{}; + auto invoker = batched_gemm.MakeInvoker(); + auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops, + BatchCount); + + if(!batched_gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result + // will not be correct. need to set time_kernel = false for correctness test + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << batched_gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto reduce0_op = ReduceOp0{}; + auto reduce1_op = ReduceOp1{}; + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto c_val = + ck::type_convert(c_g_m_n_host_result(batch, m, n)); + ReduceAccDataType d0_val; + ReduceAccDataType d1_val; + + UnaryIdenticElementOp{}(d0_val, c_val); + UnarySquareElementOp{}(d1_val, c_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + d0_g_m_host_result(batch, m) = ck::type_convert(reduce0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(reduce1_acc); + } + } + + pass = ck::utils::check_err( + c_g_m_n_host_result, c_g_m_n_device_result, "Error: Incorrect results c") && + ck::utils::check_err(d0_g_m_device_result, + d0_g_m_host_result, + "Error: Incorrect results! D0", + 1e-4, + 1e-5) && + ck::utils::check_err(d1_g_m_device_result, + d1_g_m_host_result, + "Error: Incorrect results! D1", + 1e-3, + 1e-5); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/19_binary_elementwise/CMakeLists.txt b/3rdparty/composable_kernel/example/19_binary_elementwise/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..39646e0ab5e05415dbed5dfefd9d644f1dcf2988 --- /dev/null +++ b/3rdparty/composable_kernel/example/19_binary_elementwise/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp) +add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp) +add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) +add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) \ No newline at end of file diff --git a/3rdparty/composable_kernel/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/3rdparty/composable_kernel/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9eae27ca6e92996e87f328cb1a32500dc7f1d658 --- /dev/null +++ b/3rdparty/composable_kernel/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + Add, + 2, + 8, + ck::Sequence<8, 8>, + ck::Sequence<8>>; + +template +void host_broadcast2D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + auto Amn = A(m, n); + ctype Cmn = 0; + if constexpr(broadcastDim == 0) + { + auto Bn = B(n); + functor(Cmn, Amn, Bn); + } + else + { + auto Bm = B(m); + functor(Cmn, Amn, Bm); + } + C(m, n) = Cmn; + } + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + using namespace ck::literals; + + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + }; + + Tensor a_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor b_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + a_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpaceSize()); + DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize()); + + a_m_n_device_buf.ToDevice(a_m_n.mData.data()); + b_n_device_buf.ToDevice(b_n.mData.data()); + + std::array input = {a_m_n_device_buf.GetDeviceBuffer(), + b_n_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_n_device_buf.GetDeviceBuffer()}; + + std::array abc_lengths = {M, N}; + std::array a_strides = {Stride, 1}; + std::array b_strides = {0, 1}; + std::array c_strides = {Stride, 1}; + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); + Tensor host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + host_broadcast2D, Tensor, Tensor, Add, 0>( + host_c_m_n, a_m_n, b_n, M, N, Add{}); + + pass &= ck::utils::check_err(c_m_n, host_c_m_n, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/3rdparty/composable_kernel/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..813d38b01ee9c3c44e3eef74bf445d18d87a4f26 --- /dev/null +++ b/3rdparty/composable_kernel/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + Add, + 3, + 8, + ck::Sequence<1, 8>, + ck::Sequence<8>>; + +template +void host_broadcast3D_am_bmnk(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t m = 0; m < shape[0]; ++m) + for(std::size_t n = 0; n < shape[1]; ++n) + for(std::size_t k = 0; k < shape[2]; ++k) + { + auto a_val = A(m); + auto b_val = B(m, n, k); + ctype c_val = 0; + functor(c_val, a_val, b_val); + C(m, n, k) = c_val; + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector mnk = {4, 16, 32}; + ck::index_t M = mnk[0]; + + Tensor a_m({M}); + Tensor b_m_n_k(mnk); + Tensor c_m_n_k(mnk); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize()); + DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpaceSize()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data()); + + std::array input = {a_m_device_buf.GetDeviceBuffer(), + b_m_n_k_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_n_k_device_buf.GetDeviceBuffer()}; + + std::array abc_lengths; + std::array a_strides = {1, 0, 0}; + std::array b_strides; + std::array c_strides; + + ck::ranges::copy(mnk, abc_lengths.begin()); + ck::ranges::copy(b_m_n_k.mDesc.GetStrides(), b_strides.begin()); + ck::ranges::copy(c_m_n_k.mDesc.GetStrides(), c_strides.begin()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data()); + Tensor host_c_m_n_k(mnk); + + host_broadcast3D_am_bmnk, Tensor, Tensor, Add>( + host_c_m_n_k, a_m, b_m_n_k, mnk, Add{}); + + pass &= + ck::utils::check_err(c_m_n_k, host_c_m_n_k, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/19_binary_elementwise/elementwise_add_1d.cpp b/3rdparty/composable_kernel/example/19_binary_elementwise/elementwise_add_1d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1ca9378d35edf4e3a89da6af92f48c76a2acede --- /dev/null +++ b/3rdparty/composable_kernel/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + Add, + 1, + 8, + ck::Sequence<8, 8>, + ck::Sequence<8>>; + +template +void host_elementwise1D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + auto Am = A(m); + auto Bm = B(m); + ctype Cm = 0; + functor(Cm, Am, Bm); + C(m) = Cm; + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + Tensor a_m(f_host_tensor_descriptor1d(M, 1)); + Tensor b_m(f_host_tensor_descriptor1d(M, 1)); + Tensor c_m(f_host_tensor_descriptor1d(M, 1)); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize()); + DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpaceSize()); + DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_device_buf.ToDevice(b_m.mData.data()); + + std::array input = {a_m_device_buf.GetDeviceBuffer(), + b_m_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_device_buf.GetDeviceBuffer()}; + + std::array abc_lengths = {M}; + std::array a_strides = {1}; + std::array b_strides = {1}; + std::array c_strides = {1}; + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_device_buf.FromDevice(c_m.mData.data()); + Tensor host_c_m(f_host_tensor_descriptor1d(M, 1)); + + host_elementwise1D, Tensor, Tensor, Add>( + host_c_m, a_m, b_m, M, Add{}); + + pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/19_binary_elementwise/elementwise_add_4d.cpp b/3rdparty/composable_kernel/example/19_binary_elementwise/elementwise_add_4d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..27e10014815f8adce491cc69b500b48b2fa3957a --- /dev/null +++ b/3rdparty/composable_kernel/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + Add, + 4, + 8, + ck::Sequence<8, 8>, + ck::Sequence<8>>; + +template +void host_elementwise4D(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t n = 0; n < shape[0]; ++n) + for(std::size_t c = 0; c < shape[1]; ++c) + for(std::size_t h = 0; h < shape[2]; ++h) + for(std::size_t w = 0; w < shape[3]; ++w) + { + auto a_val = A(n, c, h, w); + auto b_val = B(n, c, h, w); + ctype c_val = 0; + functor(c_val, a_val, b_val); + C(n, c, h, w) = c_val; + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector nchw = {4, 16, 32, 32}; + + Tensor a(nchw); + Tensor b(nchw); + Tensor c(nchw); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + b_device_buf.ToDevice(b.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer()}; + std::array output = {c_device_buf.GetDeviceBuffer()}; + + std::array abc_lengths; + std::array a_strides; + std::array b_strides; + std::array c_strides; + + ck::ranges::copy(nchw, abc_lengths.begin()); + ck::ranges::copy(a.mDesc.GetStrides(), a_strides.begin()); + ck::ranges::copy(b.mDesc.GetStrides(), b_strides.begin()); + ck::ranges::copy(c.mDesc.GetStrides(), c_strides.begin()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c.mData.data()); + Tensor host_c(nchw); + + host_elementwise4D, Tensor, Tensor, Add>( + host_c, a, b, nchw, Add{}); + + pass &= ck::utils::check_err(c, host_c, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..557f7971fa38ee7158c3536a71d325585d47c2d9 --- /dev/null +++ b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -0,0 +1,8 @@ +add_custom_target(example_grouped_conv_bwd_weight) + +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + + +add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16 + example_grouped_conv_bwd_weight_xdl_bf16) diff --git a/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/common.hpp b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2a8bed59d61fed543124cf8bbb340e18a94dc01 --- /dev/null +++ b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/common.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp" + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +template +struct CommonLayoutSetting +{ + using InputLayout = InputLay; + using WeightLayout = WeightLay; + using OutputLayout = OutputLay; +}; + +template +struct CommonLayoutSettingSelector; + +namespace ctl = ck::tensor_layout::convolution; + +template <> +struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<2> final + : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<3> final + : CommonLayoutSetting +{ +}; + +template +using InputLayout = typename CommonLayoutSettingSelector::InputLayout; + +template +using WeightLayout = typename CommonLayoutSettingSelector::WeightLayout; + +template +using OutputLayout = typename CommonLayoutSettingSelector::OutputLayout; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_param) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_param = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9035309c98ef2c598012c845f07a30bae68ce84f --- /dev/null +++ b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = BF16; +// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory +using WeiDataType = F32; +using OutDataType = BF16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6791b0bf68ed556c3ce93456061bd9d5404fe798 --- /dev/null +++ b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..5264c856fefbbed2a26b105535063b223926f67a --- /dev/null +++ b/3rdparty/composable_kernel/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< + NDimSpatial, // NDimSpatial + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +template +bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_param) +{ + constexpr ck::index_t split_k = 2; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< + InputLayout>(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed< + WeightLayout>(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed< + OutputLayout>(conv_param); + + Tensor in(in_g_n_c_wis_desc); + Tensor wei_host_result(wei_g_k_c_xs_desc); + Tensor wei_device_result(wei_g_k_c_xs_desc); + Tensor out(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei_host_result.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + // init to 0 + wei_device_buf.SetZero(); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; + + range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); + range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); + range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); + range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); + range_copy(conv_param.input_left_pads_, begin(input_left_pads)); + range_copy(conv_param.input_right_pads_, begin(input_right_pads)); + + // do GEMM + auto conv = DeviceConvBwdWeightInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.C_, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + if(!conv.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return false; + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl + << "DeviceOp: " << conv.GetTypeString() << std::endl; + + if(config.do_verification) + { + auto ref_conv = HostConvBwdWeightInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei_host_result, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_device_result.mData.data()); + + return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); + } + + return true; +} + +bool run_grouped_conv_bwd_weight_example(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return false; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param); + } + + return false; +} diff --git a/3rdparty/composable_kernel/example/21_gemm_layernorm/CMakeLists.txt b/3rdparty/composable_kernel/example/21_gemm_layernorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..78d3a5d02a5be1a0c7e6475aa080a5c75047ac9b --- /dev/null +++ b/3rdparty/composable_kernel/example/21_gemm_layernorm/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_single_kernel_fp16 gemm_xdl_layernorm_single_kernel_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e37555e7612f088cac22d6fbb8b9c8805e73d453 --- /dev/null +++ b/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp @@ -0,0 +1,406 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +// DataType +using ADataType = F16; +using BDataType = F16; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; +using ReduceAccDataType = F32; +using R0DataType = F32; +using R1DataType = F32; +using RsDataType = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +// Layout +using ALayout = Row; +using BLayout = Col; +using D1Layout = Row; +using ELayout = D1Layout; + +// Elementwise op +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddReluAdd; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using R0ThreadReduceOp = ck::reduce::Add; +using R1ThreadReduceOp = ck::reduce::Add; +using RsThreadReduceOp = ck::Tuple; + +static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle +//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| +//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | + < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< + ck::Tuple, // x(gemm_out), mean, meansquare, gamma, beta + ck::Tuple, // y + NormalizeFunctor, + 2, + 8, // MPerthread + ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta + ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out) + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& bias_n, + const Tensor& d1_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + AElementOp a_element_op, + BElementOp b_element_op, + CDEElementOp cde_element_op, + int M, + int N) +{ + + int StrideE = N; + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = Div{N}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + // c = activation(c + bias) + c1_functor(c1) + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + auto acc = ck::type_convert(e_m_n(m, n)); + cde_element_op(e_m_n(m, n), acc, bias_n(n), d1_m_n(m, n)); + } + + // reduce_mean and reduce_square_mean + auto r0Op = R0ThreadReduceOp{}; + auto r1Op = R1ThreadReduceOp{}; + for(int m = 0; m < M; ++m) + { + auto mean_acc = r0Op.GetIdentityValue(); + auto mean_square_acc = r1Op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto e_val = ck::type_convert(e_m_n(m, n)); + ReduceAccDataType square_e_val = 0; + Square{}(square_e_val, e_val); + + r0Op(mean_acc, e_val); + r1Op(mean_square_acc, square_e_val); + } + + averageOpInst(mean_acc, mean_acc); + averageOpInst(mean_square_acc, mean_square_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(mean_square_acc); + } + + // LayerNorm + auto layerNormInst = NormalizeFunctor{}; + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + LayerNormOutDataType out_val = 0; + layerNormInst(out_val, e_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n)); + out_m_n(m, n) = out_val; + } + } +} + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(D0DataType) * M * N + + sizeof(D0DataType) * M * N + sizeof(R0DataType) * M + + sizeof(R1DataType) * M; + + std::size_t normalize_num_byte = sizeof(EDataType) * M * N + sizeof(R0DataType) * M + + sizeof(R1DataType) * M + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time; + + std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; + + std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec + << " GB/s, " << std::endl; +} + +int main() +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 1024; + ck::index_t StrideE = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, ELayout{})); + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor r0_Mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + bias_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(D0DataType) * bias_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize()); + DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize()); + DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) * + r1_MeanSquare_m.mDesc.GetElementSpaceSize()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize()); + DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * + layerNorm_m_n.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + auto qs_element_op = QsElementOp{}; + auto rs_element_op = RsElementOp{N, N}; + + // Prepare GEMM, mean, mean_square + auto gemmReduce = DeviceOpInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = gemmReduce.MakeArgument( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + {r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()}, + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + + if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + // init reducetion buffer to 0 + r0_Mean_device_buf.SetZero(); + r1_MeanSquare_device_buf.SetZero(); + + // Prepare LayerNorm + std::array input = {e_device_buf.GetDeviceBuffer(), + r0_Mean_device_buf.GetDeviceBuffer(), + r1_MeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer()}; + std::array output = {layerNorm_device_buf.GetDeviceBuffer()}; + + std::array xyLengths = {M, N}; + std::array xyStrides = {StrideE, 1}; + + auto normalize = DeviceNormalizeInstance{}; + auto normalize_invoker = normalize.MakeInvoker(); + auto normalize_argument_ptr = + normalize.MakeArgumentPointer(xyLengths, + {xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}}, + {xyStrides}, + input, + output, + NormalizeFunctor{}); + + if(!normalize.IsSupportedArgument(normalize_argument_ptr.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device, exiting!"); + } + + // run kernel + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); + normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false}); + + bool pass = true; + { + // verification + Tensor host_layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + bias_n, + d1_m_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + cde_element_op, + M, + N); + + layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); + pass &= ck::utils::check_err(layerNorm_m_n, + host_layerNorm_m_n, + "Error: Incorrect results layerNorm_m_n", + 1e-2, + 1e-2); + } + + { + // evaluate kernel perf + bool time_kernel = true; + + float gemm_reduce_mean_reduce_square_mean_ave_time = + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + float normalize_ave_time = + normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..282c8763eba2de02a4d9b0ebfa3ffc20c629ff4a --- /dev/null +++ b/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp @@ -0,0 +1,375 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +// DataType +using ADataType = F16; +using BDataType = F16; +using GemmAccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; +using ReduceAccDataType = F32; +using R0DataType = F32; +using R1DataType = F32; +using RsDataType = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +// Layout +using ALayout = Row; +using BLayout = Col; +using D1Layout = Row; +using ELayout = D1Layout; + +// Elementwise op +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; +using QsElementOp = ck::Tuple; +using RsElementOp = ck::Tuple; + +// ReduceOp +using R0ThreadReduceOp = ck::reduce::Add; +using R1ThreadReduceOp = ck::reduce::Add; +using RsThreadReduceOp = ck::Tuple; + +static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd; +using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle +//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| +//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | + < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< + ck::Tuple, // x(gemm_out), mean, + // meansquare, + // gamma, beta + ck::Tuple, // y + NormalizeFunctor, + 2, + 8, // MPerthread + ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta + ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out) + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& gamma_n, + const Tensor& beta_n, + AElementOp a_element_op, + BElementOp b_element_op, + CDEElementOp c_element_op, + int M, + int N) +{ + int StrideE = N; + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = Div{N}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + // reduce_mean and reduce_square_mean + auto r0Op = R0ThreadReduceOp{}; + auto r1Op = R1ThreadReduceOp{}; + for(int m = 0; m < M; ++m) + { + auto mean_acc = r0Op.GetIdentityValue(); + auto mean_square_acc = r1Op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto e_val = ck::type_convert(e_m_n(m, n)); + ReduceAccDataType square_e_val = 0; + Square{}(square_e_val, e_val); + + r0Op(mean_acc, e_val); + r1Op(mean_square_acc, square_e_val); + } + + averageOpInst(mean_acc, mean_acc); + averageOpInst(mean_square_acc, mean_square_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(mean_square_acc); + } + + // LayerNorm + auto layerNormInst = NormalizeFunctor{}; + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + LayerNormOutDataType out_val = 0; + layerNormInst(out_val, e_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n)); + out_m_n(m, n) = out_val; + } + } +} + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(R0DataType) * M + + sizeof(R1DataType) * M; + + std::size_t normalize_num_btye = sizeof(EDataType) * M * N + sizeof(R0DataType) * M + + sizeof(R1DataType) * M + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + float normalize_gb_per_sec = normalize_num_btye / 1.E6 / normalize_time; + + std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; + + std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec + << " GB/s, " << std::endl; +} + +int main() +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor r0_Mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize()); + DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize()); + DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) * + r1_MeanSquare_m.mDesc.GetElementSpaceSize()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize()); + DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * + layerNorm_m_n.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + auto qs_element_op = QsElementOp{}; + auto rs_element_op = RsElementOp{N, N}; + + // Prepare GEMM, mean, mean_square + auto gemmReduce = DeviceOpInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = gemmReduce.MakeArgument( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + {r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()}, + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + + if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + r0_Mean_device_buf.SetZero(); + r1_MeanSquare_device_buf.SetZero(); + + // Prepare LayerNorm + std::array input = {e_device_buf.GetDeviceBuffer(), + r0_Mean_device_buf.GetDeviceBuffer(), + r1_MeanSquare_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer()}; + std::array output = {layerNorm_device_buf.GetDeviceBuffer()}; + + std::array xyLengths = {M, N}; + std::array xyStrides = {StrideE, 1}; + + auto normalize = DeviceNormalizeInstance{}; + auto normalize_invoker = normalize.MakeInvoker(); + auto normalize_argument_ptr = + normalize.MakeArgumentPointer(xyLengths, + {xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}}, + {xyStrides}, + input, + output, + NormalizeFunctor{}); + + if(!normalize.IsSupportedArgument(normalize_argument_ptr.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device, exiting"); + } + + // run kernel + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); + normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false}); + + bool pass = true; + { + // verification + Tensor host_layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + cde_element_op, + M, + N); + + layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); + pass &= ck::utils::check_err( + layerNorm_m_n, host_layerNorm_m_n, "Error: Incorrect results d1", 1e-3, 1e-3); + } + + { + // evaluate kernel perf + bool time_kernel = true; + + float gemm_reduce_mean_reduce_square_mean_ave_time = + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + float normalize_ave_time = + normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp b/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c3e36be6a7ab92b48c578a07c6f4a2462b095a9 --- /dev/null +++ b/3rdparty/composable_kernel/example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +// This example demonstrate a single kernel that runs GEMM layer and laynorm in one fused kernel +// +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +// +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using C0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +struct Relu +{ + template + __host__ __device__ void operator()(OutT& y, const InT& x) const + { + y = x > 0 ? x : 0; + } +}; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +// Elementwise operation that operates on the output of matrix multiplication +// i.e., AccElementOp(A * B + bias) +using AccElementOp = Relu; +// Elementwise operation that operates on the output of layer normalization +using CElementOp = Relu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle +//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadCopy| +//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4>; +// clang-format on + +using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 128; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 128; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c0_n_bias({N}); + Tensor c0_m_n_add(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c0_n_gamma({N}); + Tensor c0_n_beta({N}); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n_bias: " << c0_n_bias.mDesc << std::endl; + std::cout << "c0_m_n_add: " << c0_m_n_add.mDesc << std::endl; + std::cout << "c0_n_gamma: " << c0_n_gamma.mDesc << std::endl; + std::cout << "c0_n_beta: " << c0_n_beta.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + c0_n_bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_m_n_add.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n_gamma.GenerateTensorValue(GeneratorTensor_2{0, 2}); + c0_n_beta.GenerateTensorValue(GeneratorTensor_2{0, 5}); + c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1{0}); + acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1{0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpaceSize()); + DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpaceSize()); + DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpaceSize()); + DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c0_bias_buf.ToDevice(c0_n_bias.mData.data()); + c0_add_buf.ToDevice(c0_m_n_add.mData.data()); + c0_gamma_buf.ToDevice(c0_n_gamma.mData.data()); + c0_beta_buf.ToDevice(c0_n_beta.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto acc_element_op = AccElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(c0_add_buf.GetDeviceBuffer()), + static_cast(c0_bias_buf.GetDeviceBuffer()), + static_cast(c0_gamma_buf.GetDeviceBuffer()), + static_cast(c0_beta_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // extra 6MN flops due to: bias + add + gamma + beta + norm_sub + norm_div, + // excluding reduction steps + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(6) * M * N; + // extra MN and 3N due to c0_add (MxN), bias (1xN), gamma (1xN), beta (1xN) + std::size_t bytes = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * 2 * M * N + sizeof(C0DataType) * 3 * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + auto ref_gemm = ReferenceInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c_m_n_host_result, + c0_n_bias, + c0_m_n_add, + c0_n_gamma, + c0_n_beta, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + if constexpr(std::is_same::value) + { + pass &= ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results c"); + } + else if constexpr(std::is_same::value) + { + pass &= ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results c", 1e-2, 1e-2); + } + } + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/22_cgemm/CMakeLists.txt b/3rdparty/composable_kernel/example/22_cgemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..156456115611ca12744e7773c199138733b88cd7 --- /dev/null +++ b/3rdparty/composable_kernel/example/22_cgemm/CMakeLists.txt @@ -0,0 +1,17 @@ +add_custom_target(example_cgemm_xdl) + +add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) +add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) +add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) + +add_dependencies(example_cgemm_xdl + example_cgemm_xdl_bf16 + example_cgemm_xdl_fp16 + example_cgemm_xdl_fp32 + example_cgemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) +endif() diff --git a/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_bf16.cpp b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..92ed90ce4ab3ec09da2e053ceebe4d4ed6bcf36d --- /dev/null +++ b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_bf16.cpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "cgemm_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using CDataType = BF16; +using AccDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 8, // index_t ABlockTransferSrcScalarPerVector + 8, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // CGEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 416; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" + << std::endl; + exit(0); + } + + return !run_cgemm_xdl( + M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_common.hpp b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6aa06b7c32cb476b61a88235e13a12a3d1a15db5 --- /dev/null +++ b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_common.hpp @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using INT8 = std::int8_t; +using INT32 = std::int32_t; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using INT4 = ck::int4_t; +#endif + +template +bool run_cgemm_xdl(ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + bool do_verification, + int init_method, + bool time_kernel) +{ +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + static_assert(sizeof(ck::int4_t) == sizeof(int8_t), + "sizeof ck::int4_t and int8_t is different!"); + static_assert(sizeof(ADataType) == sizeof(KernelADataType), + "sizeof ADataType and KernelADataType is different!"); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType), + "sizeof BDataType and KernelBDataType is different!"); + static_assert(sizeof(CDataType) == sizeof(KernelCDataType), + "sizeof CDataType and KernelCDataType is different!"); +#endif + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_real_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_imag_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; + std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; + std::cout << "b_k_n_real: " << b_k_n_real.mDesc << std::endl; + std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl; + std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl; + std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k_real.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a_m_k_imag.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n_real.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n_imag.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_m_k_real.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k_imag.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_real.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_imag.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + auto cgemm = DeviceCGemmInstance{}; + + DeviceMem a_m_k_real_device_buf(sizeof(KernelADataType) * + a_m_k_real.mDesc.GetElementSpaceSize()); + DeviceMem a_m_k_imag_device_buf(sizeof(KernelADataType) * + a_m_k_imag.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_real_device_buf(sizeof(KernelBDataType) * + b_k_n_real.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_imag_device_buf(sizeof(KernelBDataType) * + b_k_n_imag.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_real_device_buf(sizeof(KernelCDataType) * + c_m_n_real_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_imag_device_buf(sizeof(KernelCDataType) * + c_m_n_imag_device_result.mDesc.GetElementSpaceSize()); + DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC)); + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + Tensor a_m_k_real_converted(a_m_k_real); + Tensor a_m_k_imag_converted(a_m_k_imag); + Tensor b_k_n_real_converted(b_k_n_real); + Tensor b_k_n_imag_converted(b_k_n_imag); + + a_m_k_real_device_buf.ToDevice(a_m_k_real_converted.mData.data()); + a_m_k_imag_device_buf.ToDevice(a_m_k_imag_converted.mData.data()); + b_k_n_real_device_buf.ToDevice(b_k_n_real_converted.mData.data()); + b_k_n_imag_device_buf.ToDevice(b_k_n_imag_converted.mData.data()); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); + a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); + b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data()); + b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data()); + } + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + // do GEMM + auto invoker = cgemm.MakeInvoker(); + auto argument = + cgemm.MakeArgument(static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_imag_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_real_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_imag_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), + static_cast(workspace_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!cgemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_cgemm with the specified compilation parameters does " + "not support this CGEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(8) * M * N * K; + std::size_t num_btype = + std::size_t(2) * + (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << cgemm.GetTypeString() << std::endl; + + if(do_verification) + { + Tensor c_m_n_real_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_imag_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + auto ref_cgemm = ReferenceCGemmInstance{}; + auto ref_invoker = ref_cgemm.MakeInvoker(); + auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real, + a_m_k_imag, + b_k_n_real, + b_k_n_imag, + c_m_n_real_host_result, + c_m_n_imag_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data()); + c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data()); + + bool result = true; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + const Tensor c_m_n_real_device_result_converted(c_m_n_real_device_result); + const Tensor c_m_n_imag_device_result_converted(c_m_n_imag_device_result); + + result = ck::utils::check_err(c_m_n_real_device_result_converted, + c_m_n_real_host_result, + "Verification error: incorrect results in real part!", + 1e-2f, + 1e-1f); + result = result && ck::utils::check_err( + c_m_n_imag_device_result_converted, + c_m_n_imag_host_result, + "Verification error: incorrect results in imaginary part!", + 1e-2f, + 1e-1f); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + result = ck::utils::check_err(c_m_n_real_device_result, + c_m_n_real_host_result, + "Verification error: incorrect results in real part!", + 1e-2f, + 1e-1f); + result = result && ck::utils::check_err( + c_m_n_imag_device_result, + c_m_n_imag_host_result, + "Verification error: incorrect results in imaginary part!", + 1e-2f, + 1e-1f); + } + + return result; + } + return true; +} diff --git a/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11373736ee8b37efb5f4082253d46f02074f011f --- /dev/null +++ b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "cgemm_xdl_common.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 8, // index_t ABlockTransferSrcScalarPerVector + 8, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // CGEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" + << std::endl; + exit(0); + } + + return !run_cgemm_xdl( + M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_fp32.cpp b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f45c18c4818726f179adda70dc12b5ea7c45b9d --- /dev/null +++ b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "cgemm_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = F32; +using BDataType = F32; +using CDataType = F32; +using AccDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 4, // index_t ABlockTransferSrcScalarPerVector + 4, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 4, // index_t BBlockTransferSrcScalarPerVector + 4, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 16>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // CGEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" + << std::endl; + exit(0); + } + + return !run_cgemm_xdl( + M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_int4.cpp b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c26a83baafd375b477845fbdd793c3918d7b7dc5 --- /dev/null +++ b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_int4.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "cgemm_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = INT4; +using BDataType = INT4; +using CDataType = INT4; +using AccDataType = INT32; +using CShuffleDataType = INT32; + +using KernelADataType = INT8; +using KernelBDataType = INT8; +using KernelCDataType = INT8; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 16, // index_t ABlockTransferSrcScalarPerVector + 16, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // CGEMM shape + ck::index_t M = 1024; + ck::index_t N = 1152; + ck::index_t K = 512; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideC = N; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return !run_cgemm_xdl( + M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_int8.cpp b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f24189861d8468848d28b9c0480da7c3f2c4fdd --- /dev/null +++ b/3rdparty/composable_kernel/example/22_cgemm/cgemm_xdl_int8.cpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "cgemm_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = INT8; +using BDataType = INT8; +using CDataType = INT8; +using AccDataType = INT32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 16, // index_t ABlockTransferSrcScalarPerVector + 16, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // CGEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" + << std::endl; + exit(0); + } + + return !run_cgemm_xdl( + M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); +} diff --git a/3rdparty/composable_kernel/example/23_softmax/CMakeLists.txt b/3rdparty/composable_kernel/example/23_softmax/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..dafe65521aab12d55f3ead8ca1ae4f78f9810bde --- /dev/null +++ b/3rdparty/composable_kernel/example/23_softmax/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_softmax_blockwise softmax_blockwise.cpp) \ No newline at end of file diff --git a/3rdparty/composable_kernel/example/23_softmax/README.md b/3rdparty/composable_kernel/example/23_softmax/README.md new file mode 100644 index 0000000000000000000000000000000000000000..37c43e9b552245cb05a50750ed93f8e730353ff3 --- /dev/null +++ b/3rdparty/composable_kernel/example/23_softmax/README.md @@ -0,0 +1,18 @@ +# Instructions for ```example_softmax_blockwise``` + +## Run ```example_softmax_blockwise``` +```bash +# -D : input 3-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +example_softmax_blockwise -D 4,128,2048 -v 1 1 1 +``` + +Result +``` +launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.0242877 ms, 259.039 GB/s, DeviceReduceSoftmax<256,M_C8_S1,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_8_OutDstVectorSize_8> +``` diff --git a/3rdparty/composable_kernel/example/23_softmax/softmax_blockwise.cpp b/3rdparty/composable_kernel/example/23_softmax/softmax_blockwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8854bf047be82b0511edd7d09032fccb99a7e8db --- /dev/null +++ b/3rdparty/composable_kernel/example/23_softmax/softmax_blockwise.cpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +using namespace ck::tensor_operation::device; + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 3; +constexpr int NumReduceDim = 1; + +using DeviceInstance = DeviceSoftmaxImpl; // OutScalarPerVector + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {8, 128, 2048}; + std::vector scales = {2.0f, 2.0f}; + + bool do_verification = true; + int init_method = 2; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +int main(int argc, char* argv[]) +{ + // Example: batched gemm C[G, M, N] applies max/sum reduction along N internally + const std::vector invariantDims{0, 1}; + const std::vector reduceDims{2}; + + SimpleAppArgs args; + + if(argc > 1) + { + if(args.processArgs(argc, argv) < 0) + return (-1); + }; + + Tensor in(args.inLengths); + Tensor out_ref(args.inLengths); + Tensor out(args.inLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + AccDataType alpha = args.scales[0]; + AccDataType beta = args.scales[1]; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + std::size_t num_thread = 1; + + if(args.do_verification) + { + switch(args.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) + out.mData[i] = out_ref.mData[i]; + }; + // std::cout << "beta = " << beta << std::endl; + // LogRangeAsType(std::cout << "tensor in: " , in.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "tensor prior out: " , out.mData, ",") << std::endl; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + if(args.do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceSoftmax; + ReferenceInstance ref; + auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims); + auto invoker = ref.MakeInvoker(); + invoker.Run(ref_arg); + // LogRangeAsType(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl; + }; + + std::vector i_inLengths; + std::vector i_inStrides; + + i_inLengths.assign(args.inLengths.begin(), args.inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + + auto device_instance = DeviceInstance{}; + + std::cout << i_inLengths.size() << ", " << i_inStrides.size() << std::endl; + + auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths, + i_inStrides, + reduceDims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + return 1; + }; + + std::string instance_name = device_instance.GetTypeString(); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + + bool pass = true; + if(args.do_verification) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + out_dev.FromDevice(out.mData.data()); + // LogRangeAsType(std::cout << "tensor out: " , out.mData, ",") << std::endl; + pass = pass && ck::utils::check_err(out, out_ref); + }; + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel}); + + std::size_t num_bytes = + in.mDesc.GetElementSize() * sizeof(InDataType) + + (beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name + << std::endl; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/24_batched_gemm/CMakeLists.txt b/3rdparty/composable_kernel/example/24_batched_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7962576e875dd6b7addd659dd9b9f3b6c1504a41 --- /dev/null +++ b/3rdparty/composable_kernel/example/24_batched_gemm/CMakeLists.txt @@ -0,0 +1,17 @@ +add_custom_target(example_batched_gemm_xdl) + +add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) +add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) + +add_dependencies(example_batched_gemm_xdl + example_batched_gemm_xdl_fp32 + example_batched_gemm_xdl_fp16 + example_batched_gemm_xdl_bfp16 + example_batched_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) +endif() diff --git a/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c684c13d0dc71e408518adec62c5e2f8974a9872 --- /dev/null +++ b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d1985f9af58ff2ba0d97db85ed5d9de7900ba10b --- /dev/null +++ b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a92a04dbe65b9fe6369e0ca1dd27f7dc983e118b --- /dev/null +++ b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp @@ -0,0 +1,58 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_int4.cpp b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e82cfe3248af558aed925731ceacb87b09a12ea --- /dev/null +++ b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_int4.cpp @@ -0,0 +1,99 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using DsDataType = ck::Tuple<>; +using EDataType = ck::int4_t; + +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelEDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl + // clang-format off + < ALayout, //ALayout + BLayout, //BLayout + DsLayout, //DsLayout + ELayout, //ELayout + KernelADataType, //ADataType + KernelBDataType, //BDataType + AccDataType, //AccDataType + CShuffleDataType, //CShuffleDataType + DsDataType, //DsDataType + KernelEDataType, //EDataType + AElementOp, //AElementwiseOperation + BElementOp, //BElementwiseOperation + CDEElementOp, //CDEElementwiseOperation + GemmDefault, //GEMMSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl + 16>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +#define BUILD_INT4_EXAMPLE +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_int8.cpp b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad22227af5625f51e1051b95a357938dcdb6b57f --- /dev/null +++ b/3rdparty/composable_kernel/example/24_batched_gemm/batched_gemm_xdl_int8.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; +using DsDataType = ck::Tuple<>; +using EDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/24_batched_gemm/run_batched_gemm_example.inc b/3rdparty/composable_kernel/example/24_batched_gemm/run_batched_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..21934add3162ce9508a82656847b056bea8402c6 --- /dev/null +++ b/3rdparty/composable_kernel/example/24_batched_gemm/run_batched_gemm_example.inc @@ -0,0 +1,240 @@ +#include + +#pragma once + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t batch_stride_A = M * K; + ck::index_t batch_stride_B = K * N; + ck::index_t batch_stride_C = M * N; + + ck::index_t batch_count = 16; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(KernelADataType)); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType)); + static_assert(sizeof(EDataType) == sizeof(KernelEDataType)); +#endif + + auto& [M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count] = problem_size; + + // GEMM shape + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + Tensor a_g_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); +#ifdef BUILD_INT4_EXAMPLE + Tensor e_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); +#else + Tensor e_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); +#endif + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor a_g_m_k_converted(a_g_m_k); + const Tensor b_g_k_n_converted(b_g_k_n); + + a_device_buf.ToDevice(a_g_m_k_converted.mData.data()); + b_device_buf.ToDevice(b_g_k_n_converted.mData.data()); +#else + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); +#endif + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + batch_count, + stride_A, + stride_B, + {}, + stride_C, + batch_stride_A, + batch_stride_B, + {}, + batch_stride_C, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + invoker.Run(argument, StreamConfig{nullptr, false}); + bool pass = true; + + if(config.do_verification) + { + c_device_buf.FromDevice(e_g_m_n_device_result.mData.data()); + + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + Tensor e_g_m_n_host_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor e_device_result_converted(e_g_m_n_device_result); + pass &= ck::utils::check_err(e_device_result_converted, e_g_m_n_host_result); + +#else + pass = ck::utils::check_err( + e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c"); +#endif + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * K * N + + sizeof(EDataType) * batch_count * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass ? 0 : 1; +} + +bool run_batched_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 15); + + problem_size.M = 256 * (dis(gen) + 1); + problem_size.N = 128 * (dis(gen) + 1); + problem_size.K = 64 * (dis(gen) + 2); + + problem_size.stride_A = problem_size.K; + problem_size.stride_B = problem_size.K; + problem_size.stride_C = problem_size.N; + + problem_size.batch_stride_A = problem_size.M * problem_size.K; + problem_size.batch_stride_B = problem_size.K * problem_size.N; + problem_size.batch_stride_C = problem_size.M * problem_size.N; + + problem_size.batch_count = 16; + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + return run_batched_gemm(problem_size, config); +} diff --git a/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/CMakeLists.txt b/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..cbc3c007bc22622554fd93f4d8a829c9ab666dc4 --- /dev/null +++ b/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) +add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c934d35019602e634da6f8b6d49c1c3b133c6800 --- /dev/null +++ b/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 1; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 3; +static constexpr ck::index_t NumDimK = 1; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G1_M2_N3_K1::Argument; + + float Run(const Argument& arg) + { + auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c; + }; + + make_ParallelTensorFunctor(f_gs_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M3_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t G0 = 1; + + ck::index_t M0 = 4; + ck::index_t M1 = 256; + + ck::index_t N0 = 4; + ck::index_t N1 = 16; + ck::index_t N2 = 32; + + ck::index_t K0 = 256; + + // A[M0, M1, M2, K0] + std::vector a_gs_ms_ks_lengths{G0, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{M0 * M1 * K0, M1 * K0, K0, 1}; + // B[N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, N0, N1, N2, K0}; + std::vector b_gs_ns_ks_strides{N0 * N1 * N2 * K0, N1 * N2 * K0, N2 * K0, K0, 1}; + + // D[N0, M0, N1, M1, N2] + std::vector d_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2}; + std::vector d_gs_ms_ns_strides{N0 * N1 * N2, 0, 0, N1 * N2, N2, 1}; + // E[N0, M0, N1, M1, N2] + std::vector e_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2}; + std::vector e_gs_ms_ns_strides{ + M0 * M1 * N0 * N1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t M = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); + + std::size_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); + + std::size_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks, + b_gs_ns_ks, + c_gs_ms_ns_host_result, + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1) + { + for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n2) + { + cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), + c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), + d_gs_ms_ns(g0, m0, m1, n0, n1, n2)); + } + } + } + } + } + } + + return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp b/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98835f98fa6ecea9c38f1155f5dd03d0787fa0d2 --- /dev/null +++ b/3rdparty/composable_kernel/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp @@ -0,0 +1,398 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 1; +static constexpr ck::index_t NumDimM = 3; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +template = + false> +struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G1_M3_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_gs_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G1_M3_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t G0 = 1; + + ck::index_t M0 = 4; + ck::index_t M1 = 8; + ck::index_t M2 = 256; + + ck::index_t N0 = 32; + ck::index_t N1 = 128; + + ck::index_t K0 = 1024; + + // A[M0, M1, M2, K0] + std::vector a_gs_ms_ks_lengths{G0, M0, M1, M2, K0}; + std::vector a_gs_ms_ks_strides{M0 * M1 * M2 * K0, M1 * M2 * K0, M2 * K0, K0, 1}; + + // B[N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[M0, N0, M1, N1, M2] + std::vector d_gs_ms_ns_lengths{G0, M0, M1, M2, N0, N1}; + std::vector d_gs_ms_ns_strides{N0 * N1, 0, 0, 0, N1, 1}; + + // E[M1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, M0, M1, M2, N0, N1}; + std::vector e_gs_ms_ns_strides{ + M0 * M1 * M2 * N1 * N0, N0 * M1 * N1, N1, M0 * N0 * M1 * N1, M1 * N1, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_gs_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks, + b_gs_ns_ks, + c_gs_ms_ns_host_result, + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1) + { + for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m2) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n1) + { + cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1), + c_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1), + d_gs_ms_ns(g0, m0, m1, m2, n0, n1)); + } + } + } + } + } + } + + return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/26_contraction/CMakeLists.txt b/3rdparty/composable_kernel/example/26_contraction/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..87f4750e3bf5dd81c001718c388038a5840d8a88 --- /dev/null +++ b/3rdparty/composable_kernel/example/26_contraction/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) +add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) diff --git a/3rdparty/composable_kernel/example/26_contraction/README.md b/3rdparty/composable_kernel/example/26_contraction/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c88d93cf83a411e2499cd0be80e524c42db339c6 --- /dev/null +++ b/3rdparty/composable_kernel/example/26_contraction/README.md @@ -0,0 +1,20 @@ +# Instructions for ```example_contraction_bilinear_xdl_fp32``` + +## Run +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_contraction_bilinear_xdl_fp32 1 1 1 +``` + +Result (MI100 @ dynammic freq, 46TFlops peak FP32) +``` +a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} +b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096} +c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} +launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4> +``` diff --git a/3rdparty/composable_kernel/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/3rdparty/composable_kernel/example/26_contraction/contraction_bilinear_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea105e4ff2bb9a33830fd1182007d658c52ba7e0 --- /dev/null +++ b/3rdparty/composable_kernel/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/26_contraction/contraction_scale_xdl_fp32.cpp b/3rdparty/composable_kernel/example/26_contraction/contraction_scale_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..26f176b0591d02a821fa783542e9538791e4cbcd --- /dev/null +++ b/3rdparty/composable_kernel/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -0,0 +1,409 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/27_layernorm/CMakeLists.txt b/3rdparty/composable_kernel/example/27_layernorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d96deae45e49acdc8b86f9973417d6a77e1e7e31 --- /dev/null +++ b/3rdparty/composable_kernel/example/27_layernorm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp) diff --git a/3rdparty/composable_kernel/example/27_layernorm/layernorm_blockwise.cpp b/3rdparty/composable_kernel/example/27_layernorm/layernorm_blockwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..147307d9ef0df3c3473791382f242f8c217c2ebe --- /dev/null +++ b/3rdparty/composable_kernel/example/27_layernorm/layernorm_blockwise.cpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using AccDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector + +int main() +{ + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = N; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + using namespace ck::literals; + + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + }; + + Tensor x(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor gamma(f_host_tensor_descriptor1d(N, 1)); + Tensor beta(f_host_tensor_descriptor1d(N, 1)); + Tensor y(f_host_tensor_descriptor2d(M, N, Stride)); + + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {M, N}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, 1}, + {0, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + {1}, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + nullptr, + nullptr, + PassThrough{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); + using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3); + } + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt b/3rdparty/composable_kernel/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..44ab16894ce054a4b25b2efd12ebac9152b3bd58 --- /dev/null +++ b/3rdparty/composable_kernel/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp b/3rdparty/composable_kernel/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8e6501eadaee4e5b4156309afebdf16e11f3e8a --- /dev/null +++ b/3rdparty/composable_kernel/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp @@ -0,0 +1,466 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimM = 3; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Packed; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //############################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################################| | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; +// clang-format on + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M3_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto m2, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, m2, k0))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, m2, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3], + arg.e_ms_ns_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M3_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + std::size_t group_count = rand() % 16 + 1; + + // GEMM shape + std::vector> contraction_descs; + std::vector p_a, p_b; + std::vector> p_ds; + std::vector p_c; + + contraction_descs.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + int M0 = 4 * (rand() % 4 + 1); + int M1 = 4 * (rand() % 4 + 1); + int M2 = 256; + + int N0 = 4 * (rand() % 4 + 1); + int N1 = 128; + + int K0 = 64 * (rand() % 4 + 1); + + // A[M0, M1, M2, K0] + std::vector a_ms_ks_lengths{M0, M1, M2, K0}; + std::vector a_ms_ks_strides{M1 * M2 * K0, M2 * K0, K0, 1}; + // B[N0, N1, K0] + std::vector b_ns_ks_lengths{N0, N1, K0}; + std::vector b_ns_ks_strides{N1 * K0, K0, 1}; +#if 0 + // D[M0, N0, M1, N1, M2] + std::vector d_ms_ns_lengths{M0, M1, M2, N0, N1}; + std::vector d_ms_ns_strides{0, 0, 0, N1, 1}; + // E[M0, N0, M1, N1, M2] + std::vector e_ms_ns_lengths{M0, M1, M2, N0, N1}; + std::vector e_ms_ns_strides{N0 * M1 * N1 * M2, N1 * M2, 1, M1 * N1 * M2, M2}; +#else + // D[M0, N0, M1, N1, M2] + std::vector d_ms_ns_lengths{M0, M1, M2, N0, N1}; + std::vector d_ms_ns_strides{0, 0, 0, N1, 1}; + // E[M0, N0, M1, N1, M2] + std::vector e_ms_ns_lengths{M0, M1, M2, N0, N1}; + std::vector e_ms_ns_strides{M1 * M2 * N0 * N1, M2 * N0 * N1, N0 * N1, N1, 1}; +#endif + + contraction_descs.push_back( + ck::tensor_operation::device::ContractionDesc<1>{a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + {d_ms_ns_lengths}, + {d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides}); + } + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> d_tensors; + std::vector> e_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, d_tensors_device, + e_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d_tensors_device.reserve(group_count); + e_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(std::size_t i = 0; i < contraction_descs.size(); i++) + { + const auto a_ms_ks_lengths = contraction_descs[i].a_ms_ks_lengths; + const auto a_ms_ks_strides = contraction_descs[i].a_ms_ks_strides; + + const auto b_ns_ks_lengths = contraction_descs[i].b_ns_ks_lengths; + const auto b_ns_ks_strides = contraction_descs[i].b_ns_ks_strides; + + const auto d_ms_ns_lengths = contraction_descs[i].ds_ms_ns_lengths[0]; + const auto d_ms_ns_strides = contraction_descs[i].ds_ms_ns_strides[0]; + + const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; + const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + ck::index_t M_ = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N_ = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K_ = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + a_tensors.push_back(a_ms_ks); + b_tensors.push_back(b_ns_ks); + d_tensors.push_back(d_ms_ns); + + // e_host_tensors.push_back(e_ms_ns_host_result); + e_device_tensors.push_back(e_ms_ns_device_result); + + flop += std::size_t(2) * M_ * K_ * N_; + + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); + + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_n_k: " << b_tensors[i].mDesc << " c_m_n: " << e_device_tensors[i].mDesc + << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_1{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_1{}); + d_tensors[i].GenerateTensorValue(GeneratorTensor_1{}); + } + } + + for(std::size_t i = 0; i < contraction_descs.size(); i++) + { + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize())); + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize())); + d_tensors_device.emplace_back(std::make_unique( + sizeof(DDataType) * d_tensors[i].mDesc.GetElementSpaceSize())); + e_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + d_tensors_device[i]->ToDevice(d_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_ds.push_back({d_tensors_device[i]->GetDeviceBuffer()}); + p_c.push_back(e_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceOpInstanceKKNN{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_a, p_b, p_ds, p_c, contraction_descs, a_element_op, b_element_op, cde_element_op); + + DeviceMem contraction_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, contraction_desc_workspace.GetDeviceBuffer()); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + for(std::size_t i = 0; i < group_count; i++) + { + const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; + const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; + + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data()); + + using ReferenceOpInstance = ReferenceContraction_M3_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_ms_ns_host_result, + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t m2 = 0; m2 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++m2) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, m2, n0, n1), + c_ms_ns_host_result(m0, m1, m2, n0, n1), + d_tensors[i](m0, m1, m2, n0, n1)); + } + } + } + } + } + + pass &= ck::utils::check_err(e_device_tensors[i], e_ms_ns_host_result); + } + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/3rdparty/composable_kernel/example/29_batched_gemm_bias_e_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..40470f27d427ac8194bb3cf55430b7fd33438597 --- /dev/null +++ b/3rdparty/composable_kernel/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp b/3rdparty/composable_kernel/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25d815b9cdfdddc29c699a1bc51dd5191a5246eb --- /dev/null +++ b/3rdparty/composable_kernel/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G2_M2_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t G0 = 1; + ck::index_t G1 = 2; + + ck::index_t M0 = 4; + ck::index_t M1 = 256; + + ck::index_t N0 = 16; + ck::index_t N1 = 128; + + ck::index_t K0 = 64; + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t G = + ck::accumulate_n(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); + + ck::index_t M = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n1) + { + cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); + } + } + } + } + } + } + + return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..61b2b2f6f3a8c8fdd866f9f5a80f965293f0dd2e --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -0,0 +1,22 @@ +add_custom_target(example_grouped_conv_fwd_multiple_d) + +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + +add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) +add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) +add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) +add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) +endif() # USE_BITINT_EXTENSION_INT4 + + +add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + +add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/README.md b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/README.md new file mode 100644 index 0000000000000000000000000000000000000000..739a0425a8c29419c3e31e433815c5230bd1a5cb --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/README.md @@ -0,0 +1,30 @@ +Command +```bash +arg1: verification (0=no, 1=yes) +arg2: initialization (0=no init, 1=integer value, 2=decimal value) +arg3: time kernel (0=no, 1=yes) +Following arguments (depending on number of spatial dims): + Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) + G, N, K, C, + , (ie Y, X for 2D) + , (ie Hi, Wi for 2D) + , (ie Sy, Sx for 2D) + , (ie Dy, Dx for 2D) + , (ie LeftPy, LeftPx for 2D) + , (ie RightPy, RightPx for 2D) + +./bin/example_grouped_conv_fwd_bias_relu_add_xdl_fp16 1 1 1 +``` + +Result (MI100) +``` +in: dim 5, lengths {1, 128, 192, 71, 71}, strides {192, 967872, 1, 13632, 192} +wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {442368, 1728, 1, 576, 192} +bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} +residual: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} +out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<256, 128, 256, 16, Default> +``` diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/common.hpp b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d6d6dd6ff1caa652461114b7626128456295e962 --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using BF16 = ck::bhalf_t; +using FP16 = ck::half_t; +using FP32 = float; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using I4 = ck::int4_t; +#endif +using I8 = std::int8_t; +using I32 = std::int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +struct CommonLayoutSetting +{ + using InputLayout = InputLay; + using WeightLayout = WeightLay; + using OutputLayout = OutputLay; +}; + +template +struct CommonLayoutSettingSelector; + +namespace ctl = ck::tensor_layout::convolution; + +template <> +struct CommonLayoutSettingSelector<1> final + : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<2> final + : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<3> final + : CommonLayoutSetting +{ +}; + +template +using InputLayout = typename CommonLayoutSettingSelector::InputLayout; + +template +using WeightLayout = typename CommonLayoutSettingSelector::WeightLayout; + +template +using OutputLayout = typename CommonLayoutSettingSelector::OutputLayout; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_param) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_param = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} + +inline HostTensorDescriptor make_input_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.N_, conv_param.C_, conv_param.input_spatial_lengths_[0]}, + { + conv_param.C_, // g + conv_param.input_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // n + 1, // c + conv_param.G_ * conv_param.C_ // wi + }); + + case 2: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.C_, + conv_param.input_spatial_lengths_[0], + conv_param.input_spatial_lengths_[1]}, + { + conv_param.C_, // g + conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] * + conv_param.G_ * conv_param.C_, // n + 1, // c + conv_param.input_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // hi + conv_param.G_ * conv_param.C_ // wi + }); + + case 3: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.C_, + conv_param.input_spatial_lengths_[0], + conv_param.input_spatial_lengths_[1], + conv_param.input_spatial_lengths_[2]}, + { + conv_param.C_, // g + conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] * + conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n + 1, // c + conv_param.input_spatial_lengths_[1] * conv_param.input_spatial_lengths_[2] * + conv_param.G_ * conv_param.C_, // di + conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // hi + conv_param.G_ * conv_param.C_ // wi + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} + +inline HostTensorDescriptor make_weight_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]}, + { + conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g + conv_param.filter_spatial_lengths_[0] * conv_param.C_, // k + 1, // c + conv_param.C_ // x + }); + case 2: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.K_, + conv_param.C_, + conv_param.filter_spatial_lengths_[0], + conv_param.filter_spatial_lengths_[1]}, + { + conv_param.K_ * conv_param.filter_spatial_lengths_[0] * + conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g + conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] * + conv_param.C_, // k + 1, // c + conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y + conv_param.C_ // x + }); + case 3: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.K_, + conv_param.C_, + conv_param.filter_spatial_lengths_[0], + conv_param.filter_spatial_lengths_[1], + conv_param.filter_spatial_lengths_[2]}, + { + conv_param.K_ * conv_param.filter_spatial_lengths_[0] * + conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] * + conv_param.C_, // g + conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] * + conv_param.filter_spatial_lengths_[2] * conv_param.C_, // k + 1, // c + conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] * + conv_param.C_, // z + conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y + conv_param.C_ // x + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} + +inline HostTensorDescriptor make_bias_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]}, + { + conv_param.K_, // g + 0, // k + 1, // c + 0 // x + }); + case 2: + return HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // ho + 0 // wo + }); + case 3: + return HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1], + conv_param.output_spatial_lengths_[2]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // z + 0, // y + 0 // x + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} + +inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]}, + { + conv_param.K_, // g + conv_param.output_spatial_lengths_[0] * conv_param.G_ * conv_param.K_, // n + 1, // k + conv_param.G_ * conv_param.K_ // wo + }); + case 2: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] * + conv_param.G_ * conv_param.K_, // n + 1, // k + conv_param.output_spatial_lengths_[1] * conv_param.G_ * conv_param.K_, // ho + conv_param.G_ * conv_param.K_ // wo + }); + + case 3: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1], + conv_param.output_spatial_lengths_[2]}, + { + conv_param.K_, // g + conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] * + conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // n + 1, // k + conv_param.output_spatial_lengths_[1] * conv_param.output_spatial_lengths_[2] * + conv_param.G_ * conv_param.K_, // do + conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // ho + conv_param.G_ * conv_param.K_ // wo + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ee300d073a28b44cf8177f64a7278d5e13ef3130 --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +// kernel data types +using InKernelDataType = BF16; +using WeiKernelDataType = BF16; +using AccDataType = FP32; +using CShuffleDataType = FP32; +using BiasKernelDataType = BF16; +using ResidualKernelDataType = BF16; +using OutKernelDataType = BF16; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a9df0b1e880c8f3bc954c07cd5de3639644c4e1 --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +// kernel data types +using InKernelDataType = FP16; +using WeiKernelDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using BiasKernelDataType = FP16; +using ResidualKernelDataType = FP16; +using OutKernelDataType = FP16; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2906cc9dd1dbe22c8d420ce0929783fb0be0fa6 --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +// kernel data types +using InKernelDataType = FP32; +using WeiKernelDataType = FP32; +using AccDataType = FP32; +using CShuffleDataType = FP32; +using BiasKernelDataType = FP32; +using ResidualKernelDataType = FP32; +using OutKernelDataType = FP32; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d5a243e6b9282f9d362c82c0758776341df646f --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#error Should compile this file with ck::int4_t support +#endif + +#include "common.hpp" + +// kernel data types +using InKernelDataType = I8; +using WeiKernelDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I8; +using BiasKernelDataType = I8; +using ResidualKernelDataType = I8; +using OutKernelDataType = I8; + +// tensor data types +using InUserDataType = I4; +using WeiUserDataType = I4; +using OutUserDataType = I4; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#define BUILD_INT4_EXAMPLE +#include "run_grouped_conv_fwd_bias_relu_add_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eaf680fa438a14ee402e3805c18692ffdcf78c7e --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +// kernel data types +using InKernelDataType = I8; +using WeiKernelDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I8; +using BiasKernelDataType = I8; +using ResidualKernelDataType = I8; +using OutKernelDataType = I8; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6de1daa3d454cf556855322deb4b71198fae5d9b --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +// kernel data types +using InKernelDataType = FP16; +using WeiKernelDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using OutKernelDataType = FP16; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +#include "run_grouped_conv_fwd_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..4561156e0bd3a935c49751802e8c3037d8735228 --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +template +struct LayoutSetting +{ + using BiasLayout = BiasLay; + using ResidualLayout = ResidualLay; +}; + +template +struct LayoutSettingSelector; + +template <> +struct LayoutSettingSelector<1> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<2> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<3> final : LayoutSetting +{ +}; + +template +using BiasLayout = typename LayoutSettingSelector::BiasLayout; + +template +using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; + +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; + +template +using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +template +bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_param) +{ + static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial"); + + const auto in_g_n_c_wis_desc = make_input_descriptor(conv_param); + const auto wei_g_k_c_xs_desc = make_weight_descriptor(conv_param); + const auto bias_g_n_k_wos_desc = make_bias_descriptor(conv_param); + const auto out_g_n_k_wos_desc = make_output_descriptor(conv_param); + + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_n_k_wos_desc); + Tensor residual(bias_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "residual: " << residual.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor in_converted(in); + const Tensor wei_converted(wei); + const Tensor bias_converted(bias); + const Tensor residual_converted(residual); + + in_device_buf.ToDevice(in_converted.mData.data()); + wei_device_buf.ToDevice(wei_converted.mData.data()); + bias_device_buf.ToDevice(bias_converted.mData.data()); + residual_device_buf.ToDevice(residual_converted.mData.data()); +#else + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + residual_device_buf.ToDevice(residual.mData.data()); +#endif + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array d1_g_n_k_wos_lengths{}; + std::array d1_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_n_k_wos_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(bias_g_n_k_wos_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(bias_g_n_k_wos_desc.GetLengths(), d1_g_n_k_wos_lengths); + copy(bias_g_n_k_wos_desc.GetStrides(), d1_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer(), + residual_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 2>{ + {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}}, + std::array, 2>{ + {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(config.do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = HostConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach([&](auto&, auto idx) { + OutElementOp{}(out_host(idx), c_host(idx), bias(idx), residual(idx)); + }); + + out_device_buf.FromDevice(out_device.mData.data()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor out_device_converted(out_device); + + return ck::utils::check_err( + out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); +#else + return ck::utils::check_err( + out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); +#endif + } + + return true; +} + +bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return false; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); + case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); + case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + } + + return false; +} diff --git a/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..d087c31af5def501635ef173769851e111dec3cb --- /dev/null +++ b/3rdparty/composable_kernel/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple<>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; + +template +using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +template +bool run_grouped_conv_fwd(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_param) +{ + static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial"); + + const auto in_g_n_c_wis_desc = make_input_descriptor(conv_param); + const auto wei_g_k_c_xs_desc = make_weight_descriptor(conv_param); + const auto out_g_n_k_wos_desc = make_output_descriptor(conv_param); + + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor in_converted(in); + const Tensor wei_converted(wei); + + in_device_buf.ToDevice(in_converted.mData.data()); + wei_device_buf.ToDevice(wei_converted.mData.data()); +#else + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); +#endif + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(config.do_verification) + { + auto ref_conv = HostConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor out_device_converted(out_device); + + return ck::utils::check_err( + out_device_converted.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); +#else + return ck::utils::check_err( + out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); +#endif + } + + return true; +} + +bool run_grouped_conv_fwd_example(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return false; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return run_grouped_conv_fwd<1>(config, conv_param); + case 2: return run_grouped_conv_fwd<2>(config, conv_param); + case 3: return run_grouped_conv_fwd<3>(config, conv_param); + } + + return false; +} diff --git a/3rdparty/composable_kernel/example/31_batched_gemm_gemm/CMakeLists.txt b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d79248251c335c5bc7d5ee2a71c14ea3ac60ac88 --- /dev/null +++ b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/CMakeLists.txt @@ -0,0 +1,8 @@ +add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) + +if(USE_BITINT_EXTENSION_INT4) +add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..74e0e07e62a99e0700213e80a066d7e15128e516 --- /dev/null +++ b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o + |------------| + Gemm0 + |---------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using B0DataType = BF16; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = BF16; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_gemm_example.inc" + +int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d5fadb8081e86a83784f6fad7e7405a7808b65b9 --- /dev/null +++ b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o + |------------| + Gemm0 + |---------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_gemm_example.inc" + +int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0dd4e0914f479eefae604ceba8115432ce1d3803 --- /dev/null +++ b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o + |------------| + Gemm0 + |---------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using B0DataType = F32; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 128, // Gemm1NPerBlock + 16, // Gemm1KPerBlock + 4, // AK1 + 4, // BK1 + 1, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 1, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_gemm_example.inc" + +int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1fd93622a1b14a72d8159517a84b9f19d2959723 --- /dev/null +++ b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o + |------------| + Gemm0 + |---------------------| + Gemm1 +*/ + +#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#error Should compile this file with ck::int4_t support +#endif + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::int4_t; +using B0DataType = ck::int4_t; +using B1DataType = ck::int4_t; +using KernelADataType = int8_t; +using KernelB0DataType = int8_t; +using KernelB1DataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using CDataType = ck::int4_t; +using KernelCDataType = int8_t; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + KernelADataType, + KernelB0DataType, + KernelB1DataType, + KernelCDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 128, // Gemm1NPerBlock + 64, // Gemm1KPerBlock + 16, // AK1 + 16, // BK1 + 4, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 4, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#define BUILD_INT4_EXAMPLE +#include "run_batched_gemm_gemm_example.inc" + +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) +static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + +int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15d98abab7dfa0f1168fd411eeae481054053000 --- /dev/null +++ b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o + |------------| + Gemm0 + |---------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using B0DataType = int8_t; +using B1DataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using CDataType = int8_t; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 128, // Gemm1NPerBlock + 64, // Gemm1KPerBlock + 16, // AK1 + 16, // BK1 + 4, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 4, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_gemm_example.inc" + +int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..7e5f1614bcf6c6133d88f9d8dad63ec4be47c9ac --- /dev/null +++ b/3rdparty/composable_kernel/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_batched_gemm_gemm_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 128; + ck::index_t BatchCount = 4; + ck::index_t StrideA = -1; + ck::index_t StrideB0 = -1; + ck::index_t StrideB1 = -1; + ck::index_t StrideC = -1; + ck::index_t BatchStrideA = -1; + ck::index_t BatchStrideB0 = -1; + ck::index_t BatchStrideB1 = -1; + ck::index_t BatchStrideC = -1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + } + else if(argc == 17) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + + StrideA = std::stoi(argv[9]); + StrideB0 = std::stoi(argv[10]); + StrideB1 = std::stoi(argv[11]); + StrideC = std::stoi(argv[12]); + + BatchStrideA = std::stoi(argv[13]); + BatchStrideB0 = std::stoi(argv[14]); + BatchStrideB1 = std::stoi(argv[15]); + BatchStrideC = std::stoi(argv[16]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " + "BatchStrideB0, BatchStrideB1, BatchStrideC\n"); + exit(0); + } + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideC = ck::is_same_v ? O : M; + + StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideC = (ck::is_same_v ? O : M) * StrideC; + + BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor c_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + Tensor c_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_g_m_k_device_buf(sizeof(KernelADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(KernelB0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(KernelB1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()); + DeviceMem c_g_m_o_device_buf(sizeof(KernelCDataType) * + c_g_m_o_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_g_m_k_converted(a_g_m_k); + const Tensor b0_g_k_n_converted(b0_g_k_n); + const Tensor b1_g_n_o_converted(b1_g_n_o); + + a_g_m_k_device_buf.ToDevice(a_g_m_k_converted.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n_converted.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o_converted.mData.data()); +#else + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()); + DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * + c_g_m_o_device_result.mDesc.GetElementSpaceSize()); + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); +#endif + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), +#else + static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + O, + BatchCount, + StrideA, + StrideB0, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideB1, + BatchStrideC, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(do_verification) + { + // Output of Gemm0 is input A of Gemm1 + Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, a1_g_m_n, a_element_op, b0_element_op, PassThrough{}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_g_m_o_device_result_converted(c_g_m_o_host_result.mDesc); + + c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result_converted.mData.data()); + + c_g_m_o_device_result = c_g_m_o_device_result_converted.CopyAsType(); +#else + c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); +#endif + + return ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result); + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8d9aaec85a503b060254ef748db77f4a7f289dae --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -0,0 +1,16 @@ +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) + +add_custom_target(example_gemm_scale_softmax_gemm) +add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) +add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) +add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) +add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) +add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) +add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) +add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0eb15653306f0179b1d132ddfcd130af6ba13fd6 --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + NumDimO, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + Acc1BiasDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecA, + TensorSpecB0, + TensorSpecB1, + TensorSpecC, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec>; // MaskingSpecialization + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f1db577c604418b7b1d546b5de251f5a20054a6 --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using B0DataType = BF16; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = BF16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + NumDimO, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + Acc1BiasDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecA, + TensorSpecB0, + TensorSpecB1, + TensorSpecC, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec>; // MaskingSpecialization + +// Ref Gemm0: bf16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, bf16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: bf16 in, bf16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ce91a8c6023314d61b354614c81ca74c5f95727 --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + NumDimO, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + Acc1BiasDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecA, + TensorSpecB0, + TensorSpecB1, + TensorSpecC, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec>; // MaskingSpecialization + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1fd2bf69306f5b35dd4ea2ae484c677c6b3c63fc --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using B0DataType = BF16; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = BF16; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + false>; + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4a8589052f0f149902553884a58cc4c6c79d030 --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + false>; + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4a71b04313446f75f4cbc73456a66bdbb62c47e --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + NumDimO, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + Acc1BiasDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecA, + TensorSpecB0, + TensorSpecB1, + TensorSpecC, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec>; // MaskingSpecialization + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_grouped_gemm_scale_softmax_gemm_permute.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..38b5badc6e4f270e409ed98c492035c5bebe00c3 --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + NumDimO, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + Acc1BiasDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecA, + TensorSpecB0, + TensorSpecB1, + TensorSpecC, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec>; // MaskingSpecialization + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_grouped_gemm_scale_softmax_gemm_permute.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc new file mode 100644 index 0000000000000000000000000000000000000000..4e43dbdd8fc5a83486cd052ee3e789e05152861f --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1020; + ck::index_t N = 1020; + ck::index_t K = 64; + ck::index_t O = 128; + ck::index_t BatchCount = 4; + ck::index_t StrideA = -1; + ck::index_t StrideB0 = -1; + ck::index_t StrideB1 = -1; + ck::index_t StrideC = -1; + ck::index_t BatchStrideA = -1; + ck::index_t BatchStrideB0 = -1; + ck::index_t BatchStrideB1 = -1; + ck::index_t BatchStrideC = -1; + float alpha = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + } + else if(argc == 18) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + + StrideA = std::stoi(argv[9]); + StrideB0 = std::stoi(argv[10]); + StrideB1 = std::stoi(argv[11]); + StrideC = std::stoi(argv[12]); + + BatchStrideA = std::stoi(argv[13]); + BatchStrideB0 = std::stoi(argv[14]); + BatchStrideB1 = std::stoi(argv[15]); + BatchStrideC = std::stoi(argv[16]); + + alpha = std::stof(argv[17]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 16: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " + "BatchStrideB0, BatchStrideB1, BatchStrideC\n"); + printf("arg17: scale (alpha)\n"); + exit(0); + } + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideC = ck::is_same_v ? O : M; + + StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideC = (ck::is_same_v ? O : M) * StrideC; + + BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({batch_stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({batch_stride, 1, stride})); + } + }; + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor c_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + Tensor c_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()); + DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * + c_g_m_o_device_result.mDesc.GetElementSpaceSize()); + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA, + StrideB0, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideB1, + BatchStrideC, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); + + if(do_verification) + { + // Output of Gemm0 is input A of Gemm1 + Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc new file mode 100644 index 0000000000000000000000000000000000000000..0b876af952f65f89e9ae55bea344bfa6f61afd5d --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::vector{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::vector{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + // TODO ANT: replace array with vector? + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + {}, // std::array p_acc0_biases; + {}, // std::array p_acc1_biases; + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // std::array, 1>{acc0_biases_gs_ms_ns_lengths}, + {}, // std::array, 1>{acc0_biases_gs_ms_ns_strides}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_lengths}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_strides}, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + ck::index_t BatchCount = G0 * G1; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = DeviceGemmInstance::C0MatrixMask(N); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + return ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol) + ? 0 + : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc new file mode 100644 index 0000000000000000000000000000000000000000..ef2acf61f55fa7f108e174d4586c13ae3dc7419d --- /dev/null +++ b/3rdparty/composable_kernel/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + input_permute = std::stoi(argv[4]); + output_permute = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 5: input / output permute\n"); + exit(0); + } + + float alpha = 1; // scaling after 1st gemm + + std::size_t group_count = 7; + + // Problem descs + std::vector problem_descs; + std::vector p_a; + std::vector p_b0; + std::vector p_b1; + std::vector p_c; + std::vector> g0_g1_m_n_k_o; + + std::vector> a_tensors; + std::vector> b0_tensors; + std::vector> b1_tensors; + std::vector> c_tensors; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_tensors_device; + std::vector b0_tensors_device; + std::vector b1_tensors_device; + std::vector c_tensors_device; + + std::size_t flop = 0, num_byte = 0; + + std::cout << "group count " << group_count << ". printing first 4 groups\n"; + for(std::size_t i = 0; i < group_count; i++) + { + int M = 128 * (rand() % 8 + 1); + int N = 128 * (rand() % 8 + 1); + int K = 40; + int O = 40 * (rand() % 2 + 1); + int G0 = rand() % 3 + 1; + int G1 = rand() % 5 + 1; + + g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O}); + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::vector{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::vector{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + problem_descs.push_back({a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // acc0_biases_gs_ms_ns_lengths + {}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}}); // acc1_biases_gs_ms_os_strides + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + int Batch = G0 * G1; + flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; + num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + Batch; + + if(i < 4) + { + std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " + << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " + << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " + << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; + } + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + a_tensors.push_back(a_gs_ms_ks); + b0_tensors.push_back(b0_gs_ns_ks); + b1_tensors.push_back(b1_gs_os_ns); + c_tensors.push_back(c_gs_ms_os_device_result); + + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize())); + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize())); + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize())); + + a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data()); + b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data()); + b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); + p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(p_a, + p_b0, + p_b1, + p_c, + {}, // p_acc0_biases + {}, // p_acc1_biases + problem_descs, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + // specify workspace for problem_desc + DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer()); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + for(std::size_t i = 0; i < group_count; i++) + { + const int& G0 = g0_g1_m_n_k_o[i][0]; + const int& G1 = g0_g1_m_n_k_o[i][1]; + const int& M = g0_g1_m_n_k_o[i][2]; + const int& N = g0_g1_m_n_k_o[i][3]; + const int& K = g0_g1_m_n_k_o[i][4]; + const int& O = g0_g1_m_n_k_o[i][5]; + + const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths; + const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides; + + const auto& a_gs_ms_ks = a_tensors[i]; + const auto& b0_gs_ns_ks = b0_tensors[i]; + const auto& b1_gs_os_ns = b1_tensors[i]; + auto& c_gs_ms_os_device_result = c_tensors[i]; + auto& c_gs_ms_os_device_buf = *c_tensors_device[i]; + + c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({G0 * G1, M, K}); + Tensor b0_g_k_n({G0 * G1, K, N}); + Tensor b1_g_n_o({G0 * G1, N, O}); + Tensor acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = DeviceGemmInstance::C0MatrixMask(N); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm 1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + bool pass_ = + ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData); + pass &= pass_; + } + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/33_multiple_reduce/CMakeLists.txt b/3rdparty/composable_kernel/example/33_multiple_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bc8c3eb04e399868db98a42604fa4b861a638fd9 --- /dev/null +++ b/3rdparty/composable_kernel/example/33_multiple_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_dual_reduce_multiblock dual_reduce_multiblock.cpp) +add_example_executable(example_dual_reduce_threadwise dual_reduce_threadwise.cpp) diff --git a/3rdparty/composable_kernel/example/33_multiple_reduce/README.md b/3rdparty/composable_kernel/example/33_multiple_reduce/README.md new file mode 100644 index 0000000000000000000000000000000000000000..90762a692fcce6f612c5094f6f7042483d0329aa --- /dev/null +++ b/3rdparty/composable_kernel/example/33_multiple_reduce/README.md @@ -0,0 +1,37 @@ +# Instructions for ```example_dual_reduce``` + +## Run ```example_dual_reduce_multiblock``` +```bash +# -D : input 4-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +./bin/example_dual_reduce_multiblock -D 600,28,28,256 -v 1 2 1 +``` + +Result +``` +./bin/example_dual_reduce_multiblock -D 600,28,28,256 -v 1 2 1 +launch_and_time_kernel: grid_dim {150, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.19529 ms, 201.499 GB/s, DeviceMultipleReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_1_InSrcVectorSize_1,OutDstVectorSize_1_1> +``` + +## Run ```example_dual_reduce_threadwise``` +```bash +# -D : input 4-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +./bin/example_dual_reduce_multiblock -D 8000,4,4,4 -v 1 2 1 +``` + +Result +``` +./bin/example_dual_reduce_threadwise -D 8000,4,4,4 -v 1 2 1 +launch_and_time_kernel: grid_dim {32, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.01512 ms, 71.9577 GB/s, DeviceMultipleReduceThreadwise<256,M_C256_S1,K_C1_S4,InSrcVectorDim_1_InSrcVectorSize_2,OutDstVectorSize_1_1> +``` diff --git a/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_common.hpp b/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..376b95ea7b307370a8e3f08f228243eebdf4d40f --- /dev/null +++ b/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_common.hpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {600, 28, 28, 256}; + size_t n, h, w, c; + + bool do_verification = true; + int init_method = 2; + bool time_kernel = true; + + public: + SimpleAppArgs() + { + n = inLengths[0]; + h = inLengths[1]; + w = inLengths[2]; + c = inLengths[3]; + }; + + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + if(inLengths.size() != 4) + throw std::runtime_error( + "Invalid option format! The number of integers is incorrect!"); + + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + n = inLengths[0]; + h = inLengths[1]; + w = inLengths[2]; + c = inLengths[3]; + + return (0); + }; +}; + +template +static void mean_meansquare_host(const Tensor& in, + Tensor& mean_ref, + Tensor& meansquare_ref, + size_t n, + size_t h, + size_t w, + size_t c) + +{ + auto thread_reduce_func = [&](auto iN) { + AccDataType mean = ck::type_convert(0.0f); + AccDataType meansquare = ck::type_convert(0.0f); + + // compute mean, meanquare, variance, invVariance + for(std::size_t iH = 0; iH < h; iH++) + { + for(std::size_t iW = 0; iW < w; iW++) + { + for(std::size_t iC = 0; iC < c; iC++) + { + AccDataType curr_value = ck::type_convert(in(iN, iH, iW, iC)); + + mean += curr_value; + meansquare += curr_value * curr_value; + }; + } + }; + + mean = mean / (h * w * c); + meansquare = meansquare / (h * w * c); + + mean_ref(iN) = ck::type_convert(mean); + meansquare_ref(iN) = ck::type_convert(meansquare); + }; + + std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t work_per_thread = (n + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; it++) + { + std::size_t iN_begin = it * work_per_thread; + std::size_t iN_end = std::min(static_cast((it + 1) * work_per_thread), n); + + auto f = [=] { + for(std::size_t iN = iN_begin; iN < iN_end; iN++) + { + thread_reduce_func(iN); + } + }; + + threads[it] = joinable_thread(f); + } +}; + +using ReduceOperation = ck::reduce::Add; + +using InElementwiseOperation_Mean = ck::tensor_operation::element_wise::PassThrough; +using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide; + +using InElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnarySquare; +using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide; + +using InElementwiseOperationTuple = + ck::Tuple; +using AccElementwiseOperationTuple = + ck::Tuple; + +template +int mean_meansquare_dual_reduce_test(size_t n, + size_t h, + size_t w, + size_t c, + bool do_verification, + int init_method, + bool time_kernel, + const std::array reduceDims) +{ + const std::vector inLengths = {n, h, w, c}; + + Tensor in(inLengths); + + std::vector outLengths{n}; + + Tensor mean_ref(outLengths); + Tensor mean(outLengths); + Tensor meansquare_ref(outLengths); + Tensor meansquare(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = mean.mDesc.GetStrides(); + + size_t invariant_total_length = n; + size_t reduce_total_length = h * w * c; + + const AccDataType alpha = ck::type_convert(1.0f); + const AccDataType beta = ck::type_convert(0.0f); + + std::size_t num_thread = 1; + + if(do_verification) + { + switch(init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); break; + case 2: in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + } + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(OutDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem meansquare_dev(sizeof(OutDataType) * meansquare.mDesc.GetElementSpaceSize()); + + in_dev.ToDevice(in.mData.data()); + + if(do_verification) + { + mean_meansquare_host( + in, mean_ref, meansquare_ref, n, h, w, c); + }; + + constexpr ck::index_t NumInputDim = Rank; + constexpr ck::index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1; + + std::array i_inLengths; + std::array i_inStrides; + std::array i_outLengths; + std::array i_outStrides; + + ck::ranges::copy(inLengths, i_inLengths.begin()); + ck::ranges::copy(inStrides, i_inStrides.begin()); + ck::ranges::copy(outLengths, i_outLengths.begin()); + ck::ranges::copy(outStrides, i_outStrides.begin()); + + auto dual_reduce_op = DeviceDualReduce{}; + + auto argument_ptr = dual_reduce_op.MakeArgumentPointer( + i_inLengths, + i_inStrides, + i_outLengths, + {i_outStrides, i_outStrides}, + reduceDims, + {&alpha, &alpha}, + {&beta, &beta}, + in_dev.GetDeviceBuffer(), + {mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()}, + ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}), + ck::make_tuple( + AccElementwiseOperation_Mean{static_cast(reduce_total_length)}, + AccElementwiseOperation_Meansquare{static_cast(reduce_total_length)})); + + if(!dual_reduce_op.IsSupportedArgument(argument_ptr.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + return (-1); + }; + + std::string reduce_name = dual_reduce_op.GetTypeString(); + + auto invoker_ptr = dual_reduce_op.MakeInvokerPointer(); + + float avg_time = 0.0f; + + avg_time += invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) + + 2 * invariant_total_length * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name + << std::endl; + + bool pass = true; + + if(do_verification) + { + mean_dev.FromDevice(mean.mData.data()); + meansquare_dev.FromDevice(meansquare.mData.data()); + pass = pass && ck::utils::check_err(mean, mean_ref); + pass = pass && ck::utils::check_err(meansquare, meansquare_ref); + }; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_multiblock.cpp b/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_multiblock.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9360599ed9e8a7517bc244e2942d6949c1f71e61 --- /dev/null +++ b/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_multiblock.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "dual_reduce_common.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InDataType = ck::half_t; +using OutDataType = float; +using OutDataTypeTuple = Tuple; +using AccDataType = float; + +// for NHWC layer-norm calculation of mean and meansquare +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +constexpr bool PropagateNan = false; + +constexpr InMemoryDataOperationEnum OutMemoryDataOperation = InMemoryDataOperationEnum::Set; + +using DeviceDualReduce = DeviceMultipleReduceMultiBlock<2, + InDataType, + AccDataType, + OutDataTypeTuple, + Rank, + NumReduceDim, + ReduceOperation, + InElementwiseOperationTuple, + AccElementwiseOperationTuple, + OutMemoryDataOperation, + PropagateNan, + 256, + 4, + 64, + 1, + 1, + 1, // InSrcVectorDim + 1, + ck::Sequence<1, 1>>; + +int main(int argc, char* argv[]) +{ + int retval = 0; + + if(argc > 1) + { + SimpleAppArgs arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + std::array reduceDims = {1, 2, 3}; + + retval = mean_meansquare_dual_reduce_test(arg.n, + arg.h, + arg.w, + arg.c, + arg.do_verification, + arg.init_method, + arg.time_kernel, + reduceDims); + } + else + { + std::array reduceDims = {1, 2, 3}; + + retval = mean_meansquare_dual_reduce_test( + 600, 28, 28, 256, true, 2, true, reduceDims); + }; + + return (retval); +} diff --git a/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_threadwise.cpp b/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_threadwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56255839e567fa0386ba09b80738f5d0d2abcc23 --- /dev/null +++ b/3rdparty/composable_kernel/example/33_multiple_reduce/dual_reduce_threadwise.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "dual_reduce_common.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InDataType = ck::half_t; +using OutDataType = float; +using OutDataTypeTuple = Tuple; +using AccDataType = float; + +// for NHWC layer-norm calculation of mean and meansquare +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +constexpr bool PropagateNan = false; + +using DeviceDualReduce = DeviceMultipleReduceThreadWise<2, + InDataType, + AccDataType, + OutDataTypeTuple, + Rank, + NumReduceDim, + ReduceOperation, + InElementwiseOperationTuple, + AccElementwiseOperationTuple, + PropagateNan, + 256, + 1, + 4, + 1, // InSrcVectorDim + 2, + ck::Sequence<1, 1>>; + +int main(int argc, char* argv[]) +{ + int retval = 0; + + if(argc > 1) + { + SimpleAppArgs arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + std::array reduceDims = {1, 2, 3}; + + retval = mean_meansquare_dual_reduce_test(arg.n, + arg.h, + arg.w, + arg.c, + arg.do_verification, + arg.init_method, + arg.time_kernel, + reduceDims); + } + else + { + std::array reduceDims = {1, 2, 3}; + + retval = mean_meansquare_dual_reduce_test( + 8000, 4, 4, 4, true, 2, true, reduceDims); + }; + + return (retval); +} diff --git a/3rdparty/composable_kernel/example/34_batchnorm/CMakeLists.txt b/3rdparty/composable_kernel/example/34_batchnorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d964f40d877ddb5dc6e03cf8d28897e244514d96 --- /dev/null +++ b/3rdparty/composable_kernel/example/34_batchnorm/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp) +add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp) +add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp) diff --git a/3rdparty/composable_kernel/example/34_batchnorm/README.md b/3rdparty/composable_kernel/example/34_batchnorm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..294e32b998c7d4e258105672470c930161e85648 --- /dev/null +++ b/3rdparty/composable_kernel/example/34_batchnorm/README.md @@ -0,0 +1,81 @@ +# Instructions for ```batchnorm nhwc``` Example + +## Run ```batchnorm forward nhwc``` +```bash +# -D : input 4-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64) +#arg2: 1/0 to indicate whether to update the moving average and variance (0=no, 1=yes) +#arg3: 1/0 to indicate whether to save result mean/invVariance (0=no, 1=yes) +#arg4: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg5: time kernel (0=no, 1=yes) +./bin/example_batchnorm_forward -D 128,16,16,1024 -v 1 0 0 1 2 1 +``` + +Result +``` +./bin/example_batchnorm_forward -D 128,16,16,1024 -v 1 0 0 1 2 1 +launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +launch_and_time_kernel: grid_dim {120, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +launch_and_time_kernel: grid_dim {120, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 2.08231 ms, 354.519 GB/s +``` + +Result +``` +./bin/example_batchnorm_forward -D 128,16,16,1024 -v 1 0 1 0 2 0 +echo $? +0 +``` + +## Run ```batchnorm infer nhwc``` +```bash +# -D : input 4-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_batchnorm_infer -D 128,16,16,1024 -v 1 0 2 1 +``` + +Result +``` +./bin/example_batchnorm_infer -D 128,16,16,1024 -v 1 0 2 1 +launch_and_time_kernel: grid_dim {120, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.28235 ms, 523.329 GB/s +``` + +## Run ```batchnorm backward nhwc``` +```bash +# -D : input 4-d tensor lengths +# -v : verification (0=no, 1=yes) +Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64) +Arg2 -- 1/0 to indicate whether to use saved mean and invVariance +Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +Arg4 -- time kernel (0=no, 1=yes) +Arg5: use multi-block welford (0=n0, 1=yes) +./bin/example_batchnorm_backward -D 128,16,3,1024 -v 1 0 0 3 1 1 +``` + +Result +``` +./bin/example_batchnorm_backward -D 128,16,3,1024 -v 1 0 0 3 1 1 +launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.411026 ms, 91.8702 GB/s +``` diff --git a/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_backward_nhwc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a6ca9d150bd918966bae06a63ee4eb9da6be5f3f --- /dev/null +++ b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" + +static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchNormBwdArg +{ + private: + int option_index = 0; + + public: + std::vector inOutLengths; + + bool do_verification = false; + + bool haveSavedMeanInvVar; + + int data_type = 0; + int init_method = 3; + bool time_kernel = false; + bool use_multiblock_welford = false; + + public: + void show_usage(const char* cmd) + { + // clang-format off + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2 -- 1/0 to indicate whether to use saved mean and invVariance" << std::endl; + std::cout << "Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl; + std::cout << "Arg4 -- time kernel (0=no, 1=yes)" << std::endl; + std::cout << "Arg5: use multi-block welford (0=n0, 1=yes)" << std::endl; + // clang-format on + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inOutLengths = getTypeValuesFromString(optarg); + + if(inOutLengths.size() != 4) + throw std::runtime_error( + "NHWC tensor layout should have 4 length values specified!"); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 5 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + haveSavedMeanInvVar = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind++])); + use_multiblock_welford = static_cast(std::atoi(argv[optind])); + + return (0); + }; +}; + +using namespace ck; + +template +bool bnorm_bwd_nhwc_test(bool do_verification, + int init_method, + bool time_kernel, + const std::vector inOutLengths, + bool haveSavedMeanInvVar, + double epsilon) +{ + // for NHWC BatchNorm calculation of mean and meansquare + constexpr index_t Rank = 4; + constexpr index_t NumReduceDim = 3; + + using ScaleDataType = XDataType; + + const std::vector scaleBiasMeanVarLengths = {inOutLengths[3]}; + + // input data of the batchnorm backward algorithm + Tensor x(inOutLengths); + Tensor dy(inOutLengths); + + Tensor bnScale(scaleBiasMeanVarLengths); + + Tensor savedMean(scaleBiasMeanVarLengths); + Tensor savedInvVar(scaleBiasMeanVarLengths); + // savedVariance is only used for initializing savedInvVar + Tensor savedVariance(scaleBiasMeanVarLengths); + + // output data of the batchnorm backward algorithm + Tensor dx_ref(inOutLengths); + Tensor dx(inOutLengths); + + Tensor dscale(scaleBiasMeanVarLengths); + Tensor dbias(scaleBiasMeanVarLengths); + + Tensor dscale_ref(scaleBiasMeanVarLengths); + Tensor dbias_ref(scaleBiasMeanVarLengths); + + auto inOutStrides = dy.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = dscale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(haveSavedMeanInvVar) + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.0001f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the savedMean to be values with tiny variation to the mean of the x values + savedMean.GenerateTensorValue(GeneratorTensor_4{x_mean, noise_stddev}, + num_thread); + + // initialize the variance to be values with tiny variation to the variance of the x values + savedVariance.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + + auto it_src = savedVariance.mData.begin(); + auto it_dst = savedInvVar.mData.begin(); + float tmp_epsilon = std::numeric_limits::epsilon(); + + while(it_src != savedVariance.mData.end()) + { + *it_dst = type_convert( + 1.0f / std::sqrtf(type_convert(*it_src) + tmp_epsilon)); + + it_src++; + it_dst++; + }; + } + else + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + dy.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{-0.2f, 0.2f}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_3{-0.5f, 0.5f}, num_thread); + } + }; + + // input data of the batchnorm backward algorithm + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize()); + + DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize()); + + DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize()); + DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize()); + + // output data of the batchnorm backward algorithm + DeviceMem dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize()); + + DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize()); + DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + dy_dev.ToDevice(dy.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + + if(haveSavedMeanInvVar) + { + savedMean_dev.ToDevice(savedMean.mData.data()); + savedInvVar_dev.ToDevice(savedInvVar.mData.data()); + }; + + std::array i_inOutLengths; + std::array i_inOutStrides; + std::array i_scaleBiasMeanVarLengths; + std::array i_scaleBiasMeanVarStrides; + + std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin()); + std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin()); + std::copy(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + i_scaleBiasMeanVarLengths.begin()); + std::copy(scaleBiasMeanVarStrides.begin(), + scaleBiasMeanVarStrides.end(), + i_scaleBiasMeanVarStrides.begin()); + + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceBatchNormBwdInstance = + ck::tensor_operation::device::DeviceBatchNormBwdImpl; // MeanVarSrcVectorSize + + auto batchnorm_bwd = DeviceBatchNormBwdInstance{}; + + auto argument_ptr = batchnorm_bwd.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + dy_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr, + haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr, + epsilon, + PassThroughOp{}, + dx_dev.GetDeviceBuffer(), + dscale_dev.GetDeviceBuffer(), + dbias_dev.GetDeviceBuffer()); + + if(!batchnorm_bwd.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, " + "exiting!" + << std::endl; + return (false); + }; + + size_t workspace_sz = batchnorm_bwd.GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + batchnorm_bwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = batchnorm_bwd.MakeInvokerPointer(); + + if(time_kernel) + { + float avg_time = 0.0f; + size_t num_bytes = 0; + + size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3]; + size_t invariant_length = inOutLengths[3]; + + avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // inputing of x, dy, scale, outputing of dx, dscale, dbias + num_bytes += + total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3; + + // outputing of mean, inv-variance + num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + } + else + (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + + if(do_verification) + { + using ReferenceBatchNormBwdInstance = + ck::tensor_operation::host::ReferenceBatchNormBwd; + + auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{}; + + auto argument_ptr_ref = batchNormBwd_ref.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x.mData.data(), + dy.mData.data(), + bnScale.mData.data(), + haveSavedMeanInvVar ? savedMean.mData.data() : nullptr, + haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr, + epsilon, + PassThroughOp{}, + dx_ref.mData.data(), + dscale_ref.mData.data(), + dbias_ref.mData.data()); + + if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout + << "The runtime parameters seems not supported by the device instance, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = batchNormBwd_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + + dx_dev.FromDevice(dx.mData.data()); + dscale_dev.FromDevice(dscale.data()); + dbias_dev.FromDevice(dbias.data()); + + // clang-format off + pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4); + pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4); + pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:"); + // clang-format on + }; + + return (pass); +}; + +static const double epsilon = std::numeric_limits::epsilon(); + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + BatchNormBwdArg arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + if(arg.use_multiblock_welford) + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + else + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + } + else if(arg.data_type == 1) + { + if(arg.use_multiblock_welford) + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + else + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + } + else if(arg.data_type == 5) + { + if(arg.use_multiblock_welford) + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + else + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + } + else if(arg.data_type == 6) + { + if(arg.use_multiblock_welford) + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + else + pass = bnorm_bwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.haveSavedMeanInvVar, + epsilon); + } + } + else + { + pass = bnorm_bwd_nhwc_test(true, + 3, + false, // don't time kernel + {128, 16, 6, 512}, + false, + epsilon); + + pass = pass && bnorm_bwd_nhwc_test(true, + 3, + false, // don't time kernel + {128, 16, 3, 1024}, + false, + epsilon); + }; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_common.hpp b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bdbc8ea8b88f40de2f25a2ec8c5a74ab5e38fd74 --- /dev/null +++ b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_common.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" + +struct NormalizeInInfer +{ + NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {} + + template + __host__ __device__ constexpr void operator()(T1& y, + const T1& x, + const T2& mean, + const T2& variance, + const T3& gamma, + const T4& beta) const + { + static_assert(std::is_same::value || std::is_same::value, + "Data type is not supported by this operation!"); + + using ck::type_convert; + using ck::math::sqrt; + + T2 tmp_x, tmp_y; + + tmp_x = type_convert(x); + + tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert(epsilon_))) * + type_convert(gamma) + + type_convert(beta); + y = type_convert(tmp_y); + }; + + double epsilon_; +}; + +template +static inline std::array +get_invariant_dims(const std::array& reduceDims) +{ + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + std::array invariantDims; + + // collect invariant dimensions + int dim = 0; + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + invariantDims[dim] = i; + dim++; + }; + + return invariantDims; +}; diff --git a/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dc2984851a02e0020c8c4cfa6d59e61bfe49c593 --- /dev/null +++ b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp" + +#include "batchnorm_infer_impl.hpp" + +static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchNormInferArg +{ + private: + int option_index = 0; + + public: + std::vector inOutLengths; + + bool do_verification = false; + + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension " + "lengths, must have 4 integers for nhwc" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization " + "result by " + "comparing with the host-based batch-normalization" + << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2: init method used for bnScale and bnBias (0=no init, 1=single integer " + "value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg3: time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inOutLengths = getTypeValuesFromString(optarg); + + if(inOutLengths.size() != 4) + throw std::runtime_error( + "NHWC tensor layout should have 4 length values specified!"); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 3 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return (-1); + + return (0); + }; +}; + +using namespace ck; + +template +bool bnorm_infer_nhwc_test(bool do_verification, + int init_method, + bool time_kernel, + const std::vector inOutLengths, + double epsilon) +{ + // for NHWC BatchNorm calculation of mean and meansquare + constexpr int Rank = 4; + constexpr int NumReduceDim = 3; + + // when using lengths[] to create a tensor, lengths[0] is the length of highest dimension + // eg. N of NHWC, so lengths[3] is the dimension C length of NHWC + const std::vector scaleBiasMeanVarLengths = {inOutLengths[3]}; + + // input data of the batchnorm forward algorithm + Tensor x(inOutLengths); + Tensor bnScale(scaleBiasMeanVarLengths); + Tensor bnBias(scaleBiasMeanVarLengths); + + // output data of the batchnorm forward algorithm + Tensor y_ref(inOutLengths); + Tensor y(inOutLengths); + + Tensor estimatedMean(scaleBiasMeanVarLengths); + Tensor estimatedVariance(scaleBiasMeanVarLengths); + + auto inOutStrides = x.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if constexpr(std::is_same::value) + { + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + + const float x_mean = 0.0f; + const float x_stddev = 2.5f; + const float noise_stddev = 0.0001f; + + estimatedMean.GenerateTensorValue(GeneratorTensor_4{x_mean, noise_stddev}, + num_thread); + + estimatedVariance.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + } + else + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.0001f; + + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the savedMean to be values with tiny variation to the mean of the x values + estimatedMean.GenerateTensorValue(GeneratorTensor_4{x_mean, noise_stddev}, + num_thread); + + // initialize the variance to be values with tiny variation to the variance of the x values + estimatedVariance.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_1{0}, num_thread); + break; + case 2: + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + bnScale.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + } + }; + + // these buffers are usually provided by the user application + DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize()); + DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize()); + + // mean_dev or resultSaveMean_dev + DeviceMem estimatedMean_dev(sizeof(AccDataType) * estimatedMean.mDesc.GetElementSpaceSize()); + // meansquare_dev or resultSaveInvVariance_dev + DeviceMem estimatedVariance_dev(sizeof(AccDataType) * + estimatedVariance.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + bnBias_dev.ToDevice(bnBias.mData.data()); + estimatedMean_dev.ToDevice(estimatedMean.mData.data()); + estimatedVariance_dev.ToDevice(estimatedVariance.mData.data()); + + using ck::index_t; + + std::array i_inOutLengths; + std::array i_inOutStrides; + std::array i_scaleBiasMeanVarLengths; + std::array i_scaleBiasMeanVarStrides; + + ck::ranges::copy(inOutLengths, i_inOutLengths.begin()); + ck::ranges::copy(inOutStrides, i_inOutStrides.begin()); + ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin()); + ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin()); + + int result = 0; + + result = bnorm_infer(time_kernel, + {0, 1, 2}, + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + bnBias_dev.GetDeviceBuffer(), + epsilon, + estimatedMean_dev.GetDeviceBuffer(), + estimatedVariance_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer()); + + if(result < 0) + return (false); + + bool pass = true; + + if(do_verification) + { + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + using ReferenceBatchNormInferInstance = + ck::tensor_operation::host::ReferenceBatchNormInfer; + auto batchNormInfer_ref = ReferenceBatchNormInferInstance{}; + + auto argument_ptr_ref = + batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x.mData.data(), + bnScale.mData.data(), + bnBias.mData.data(), + epsilon, + PassThroughOp{}, + estimatedMean.mData.data(), + estimatedVariance.mData.data(), + y_ref.mData.data()); + + if(!batchNormInfer_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout + << "The runtime parameters seems not supported by the BatchNorm instance, exiting!" + << std::endl; + return (-2); + }; + + auto invoker_ptr_ref = batchNormInfer_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + + y_dev.FromDevice(y.mData.data()); + pass = pass && ck::utils::check_err(y, y_ref); + }; + + return (pass); +}; + +static const double epsilon = std::numeric_limits::epsilon(); + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + BatchNormInferArg arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + pass = bnorm_infer_nhwc_test( + arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon); + } + else if(arg.data_type == 1) + { + pass = bnorm_infer_nhwc_test( + arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon); + } + else if(arg.data_type == 3) + { + pass = bnorm_infer_nhwc_test( + arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon); + } + else if(arg.data_type == 5) + { + pass = bnorm_infer_nhwc_test( + arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon); + } + else if(arg.data_type == 6) + { + pass = bnorm_infer_nhwc_test( + arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths, epsilon); + }; + } + else + { + pass = bnorm_infer_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 16, 1024}, + epsilon); + }; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da36d65a2954ee1ac3a94c617a219e6a7c2baf44 --- /dev/null +++ b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -0,0 +1,591 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchNormFwdArg +{ + private: + int option_index = 0; + + public: + std::vector inOutLengths; + + bool do_verification = false; + + bool updateMovingAverage; + bool saveMeanAndInvVariance; + + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + bool use_multiblock_welford = false; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension " + "lengths, must have 4 integers for nhwc" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization " + "result by " + "comparing with the host-based batch-normalization" + << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance " + "(0=no, 1=yes)" + << std::endl; + std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance " + "(0=no, 1=yes)" + << std::endl; + std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer " + "value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; + std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inOutLengths = getTypeValuesFromString(optarg); + + if(inOutLengths.size() != 4) + throw std::runtime_error( + "NHWC tensor layout should have 4 length values specified!"); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 6 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + updateMovingAverage = std::atoi(argv[optind++]); + saveMeanAndInvVariance = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind++])); + use_multiblock_welford = static_cast(std::atoi(argv[optind])); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return (-1); + + return (0); + }; +}; + +using namespace ck; + +template +bool bnorm_fwd_nhwc_test(bool do_verification, + int init_method, + bool time_kernel, + const std::vector inOutLengths, + bool updateMovingAverage, + bool saveMeanAndInvVariance, + double averageFactor, + double epsilon) +{ + // for NHWC BatchNorm calculation of mean and meansquare + constexpr int Rank = 4; + constexpr int NumReduceDim = 3; + + // when using lengths[] to create a tensor, lengths[0] is the length of highest dimension + // eg. N of NHWC, so lengths[3] is the dimension C length of NHWC + const std::vector scaleBiasMeanVarLengths = {inOutLengths[3]}; + + // input data of the batchnorm forward algorithm + Tensor x(inOutLengths); + Tensor bnScale(scaleBiasMeanVarLengths); + Tensor bnBias(scaleBiasMeanVarLengths); + + // output data of the batchnorm forward algorithm + Tensor y_ref(inOutLengths); + Tensor y(inOutLengths); + + Tensor resultSaveMean_ref(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance_ref(scaleBiasMeanVarLengths); + + Tensor resultRunningMean_ref(scaleBiasMeanVarLengths); + Tensor resultRunningVariance_ref(scaleBiasMeanVarLengths); + + auto inOutStrides = x.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(updateMovingAverage) + { + if constexpr(std::is_same::value) + { + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + + const float x_mean = 0.0f; + const float x_stddev = 2.5f; + const float noise_stddev = 0.04f; + + resultRunningMean_ref.GenerateTensorValue( + GeneratorTensor_4{x_mean, noise_stddev}, num_thread); + + resultRunningVariance_ref.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + } + else + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.04f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the runningMean to be values with tiny variation to the mean of the x + // values + resultRunningMean_ref.GenerateTensorValue( + GeneratorTensor_4{x_mean, noise_stddev}, num_thread); + + // initialize the runningVariance to be values with tiny variation to the variance of + // the x values + resultRunningVariance_ref.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + }; + } + else + { + if constexpr(std::is_same::value) + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + else + x.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_1{0}, num_thread); + break; + case 2: + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + bnScale.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + } + }; + + // these buffers are usually provided by the user application + DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize()); + DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize()); + + // mean_dev or resultSaveMean_dev + DeviceMem resultSaveMean_dev(sizeof(AccDataType) * + resultSaveMean_ref.mDesc.GetElementSpaceSize()); + // meansquare_dev or resultSaveInvVariance_dev + DeviceMem resultSaveInvVariance_dev(sizeof(AccDataType) * + resultSaveInvVariance_ref.mDesc.GetElementSpaceSize()); + // resultRunningMean_dev + DeviceMem resultRunningMean_dev(sizeof(AccDataType) * + resultRunningMean_ref.mDesc.GetElementSpaceSize()); + // resultRunningVariance_dev + DeviceMem resultRunningVariance_dev(sizeof(AccDataType) * + resultRunningVariance_ref.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + bnBias_dev.ToDevice(bnBias.mData.data()); + + if(updateMovingAverage) + { + resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data()); + resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data()); + }; + + std::array i_inOutLengths; + std::array i_inOutStrides; + std::array i_scaleBiasMeanVarLengths; + std::array i_scaleBiasMeanVarStrides; + + ck::ranges::copy(inOutLengths, i_inOutLengths.begin()); + ck::ranges::copy(inOutStrides, i_inOutStrides.begin()); + ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin()); + ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin()); + + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceBatchNormFwdInstance = + ck::tensor_operation::device::DeviceBatchNormFwdImpl; + + auto batchnorm_fwd = DeviceBatchNormFwdInstance{}; + + auto argument_ptr = batchnorm_fwd.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[] + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + bnBias_dev.GetDeviceBuffer(), + epsilon, + PassThroughOp{}, + y_dev.GetDeviceBuffer(), + saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, + updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr); + + if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, " + "exiting!" + << std::endl; + return (false); + }; + + size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer(); + + if(time_kernel) + { + float avg_time = 0.0f; + size_t num_bytes = 0; + + size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3]; + size_t invariant_length = inOutLengths[3]; + + avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // inputing of x, scale, bias, outputing of y + num_bytes += + total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2; + + // outputing of mean, inv-variance + num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0; + + // updating of moving mean, variance + num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + } + else + (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + + if(do_verification) + { + + using ReferenceBatchNormFwdInstance = + ck::tensor_operation::host::ReferenceBatchNormFwd; + + auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{}; + + auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[] + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x.mData.data(), + bnScale.mData.data(), + bnBias.mData.data(), + epsilon, + PassThroughOp{}, + y_ref.mData.data(), + saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, + updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr); + + if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm reference " + "instance, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + + y_dev.FromDevice(y.mData.data()); + pass = pass && ck::utils::check_err(y, y_ref); + + if(updateMovingAverage) + { + Tensor resultRunningMean(scaleBiasMeanVarLengths); + Tensor resultRunningVariance(scaleBiasMeanVarLengths); + + resultRunningMean_dev.FromDevice(resultRunningMean.mData.data()); + resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); + + pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref); + pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref); + }; + + if(saveMeanAndInvVariance) + { + using ck::host_common::dumpBufferToFile; + + Tensor resultSaveMean(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance(scaleBiasMeanVarLengths); + + resultSaveMean_dev.FromDevice(resultSaveMean.mData.data()); + resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); + + pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref); + pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref); + }; + }; + + return (pass); +}; + +const double epsilon = std::numeric_limits::epsilon(); +static const double averageFactor = 0.1; + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + BatchNormFwdArg arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 1) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 3) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 5) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 6) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + } + else + { + pass = bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 6, 512}, + true, + true, + averageFactor, + epsilon); + + pass = pass && bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 3, 1024}, + true, + true, + averageFactor, + epsilon); + }; + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_infer_impl.hpp b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_infer_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e457df81da2087ebf04b0e676964ec857b91ee77 --- /dev/null +++ b/3rdparty/composable_kernel/example/34_batchnorm/batchnorm_infer_impl.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" + +#include "batchnorm_common.hpp" + +template +int bnorm_infer( + bool time_kernel, + const std::array reduceDims, + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_scale, + const void* p_bias, + double epsilon, + const void* p_estimatedMean, + const void* p_estimatedVariance, + void* p_y) +{ + (void)bnScaleBiasMeanVarLengths; + + static_assert(NumBatchNormReduceDim < Rank, + "Invalid number of reduced dimensions for batchnorm!"); + + using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< + ck::Tuple, // x, mean, + // variance, + // scale, + // bias, + ck::Tuple, // y + NormalizeInInfer, + Rank, + 2, // MPerthread + ck::Sequence<1, 1, 1, 1, 1>, // x, mean, variance, scale, bias + ck::Sequence<1>>; // scalarPerVector: y + + auto invariantDims = get_invariant_dims(reduceDims); + std::array aligned_bnScaleStrides{0}; + std::array aligned_bnBiasStrides{0}; + std::array aligned_bnMeanVarStrides{0}; + + int i = 0; + for(auto dim : invariantDims) + { + assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]); + + aligned_bnScaleStrides[dim] = bnScaleStrides[i]; + aligned_bnBiasStrides[dim] = bnBiasStrides[i]; + aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i]; + i++; + }; + + int32_t reduceLength = 1; + + for(auto dim : reduceDims) + reduceLength *= xyLengths[dim]; + + int32_t invariantLength = 1; + + for(auto dim : invariantDims) + invariantLength *= xyLengths[dim]; + + size_t total_length = static_cast(invariantLength) * reduceLength; + + float avg_time = 0.0f; + std::size_t num_bytes = 0; + + auto dev_normalize = DeviceNormalizeInstance{}; + + auto argument_ptr1 = dev_normalize.MakeArgumentPointer( + xyLengths, + {xStrides, + aligned_bnMeanVarStrides, + aligned_bnMeanVarStrides, + aligned_bnScaleStrides, + aligned_bnBiasStrides}, + {yStrides}, + {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias}, + {p_y}, + NormalizeInInfer{epsilon}); + + if(!dev_normalize.IsSupportedArgument(argument_ptr1.get())) + { + std::cout << "The runtime parameters seems not supported by the Devic, exiting!" + << std::endl; + + return (-1); + }; + + auto invoker_ptr1 = dev_normalize.MakeInvokerPointer(); + + avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel}); + + num_bytes += total_length * sizeof(XDataType) + + invariantLength * + (sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) + + total_length * sizeof(YDataType); + + if(time_kernel) + { + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + }; + + return (0); +}; diff --git a/3rdparty/composable_kernel/example/35_splitK_gemm/CMakeLists.txt b/3rdparty/composable_kernel/example/35_splitK_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..794583954674f403a4902a6fc1cb244b803479d5 --- /dev/null +++ b/3rdparty/composable_kernel/example/35_splitK_gemm/CMakeLists.txt @@ -0,0 +1,17 @@ +add_custom_target(example_splitK_gemm_xdl) + +add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) +add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) +add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) +add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + +add_dependencies(example_splitK_gemm_xdl + example_splitK_gemm_xdl_fp32 + example_splitK_gemm_xdl_fp16 + example_splitK_gemm_xdl_bfp16 + example_splitK_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) +endif() diff --git a/3rdparty/composable_kernel/example/35_splitK_gemm/run_splitK_gemm_example.inc b/3rdparty/composable_kernel/example/35_splitK_gemm/run_splitK_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..e9bd5c552dea937deb78859b94b731eec6ec03b0 --- /dev/null +++ b/3rdparty/composable_kernel/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -0,0 +1,217 @@ +#pragma once + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t k_batch = 4; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(KernelADataType)); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType)); +#endif + + auto& [M, N, K, StrideA, StrideB, StrideC, KBatch] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + c_m_n_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), +#endif + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + invoker.Run(argument, StreamConfig{nullptr, false}); + bool pass = true; + + if(config.do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + if(std::is_same::value) + { + pass &= ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "fp16 incorrect result", 3e-3, 1e-3); + } + else + { + pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + } + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_splitK_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.k_batch = std::stoi(argv[4]); + } + else if(argc == 11) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.k_batch = std::stoi(argv[4]); + + problem_size.M = std::stoi(argv[5]); + problem_size.N = std::stoi(argv[6]); + problem_size.K = std::stoi(argv[7]); + + problem_size.stride_A = std::stoi(argv[8]); + problem_size.stride_B = std::stoi(argv[9]); + problem_size.stride_C = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: KBatch\n"); + printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + return run_splitK_gemm(problem_size, config); +} diff --git a/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7191ecf50ab5252b37855072b4e03f1194df52f6 --- /dev/null +++ b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..efdb315b4e5576aef8d992f6d57c231ac1f83a81 --- /dev/null +++ b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc2e3d1d52b668a2ea15c9f4ea1db44c4faeb19e --- /dev/null +++ b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4eb27824628d8f193ca6108a589794041fc2b2fb --- /dev/null +++ b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using AccDataType = int32_t; +using CDataType = int32_t; + +using KernelADataType = int8_t; +using KernelBDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<0, 2, 1, 3>, // ABlockTransfer ThreadCluster ArrangeOrder + S<0, 2, 1, 3>, // ABlockTransfer SrcAccessOrder + 3, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + true, // ABlockLdsExtraM + S<1, 4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<0, 1, 3, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<0, 1, 3, 2>, // BBlockTransfer SrcAccessOrder + 3, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + true, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CBlockTransferClusterLengths _MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +#define BUILD_INT4_EXAMPLE +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eefdbca6b1ae3e49cc221b3037e46888b4a9a9bd --- /dev/null +++ b/3rdparty/composable_kernel/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CDataType = int32_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/36_sparse_embedding/CMakeLists.txt b/3rdparty/composable_kernel/example/36_sparse_embedding/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..16ca6ea41112776ac83434c4188b294d6ef027bd --- /dev/null +++ b/3rdparty/composable_kernel/example/36_sparse_embedding/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_sparse_embedding3_forward_layernorm sparse_embedding3_forward_layernorm.cpp) diff --git a/3rdparty/composable_kernel/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/3rdparty/composable_kernel/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5eb4c3b6b0441451e9a6a667321b8d843dc36b8 --- /dev/null +++ b/3rdparty/composable_kernel/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp" + +// using EmbType = float; +// using IndexType = int64_t; +// using GammaDataType = float; +// using BetaDataType = float; +// using AccDataType = float; +// using OutType = float; + +using EmbType = ck::half_t; +using IndexType = int64_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using AccDataType = float; +using OutType = ck::half_t; + +// clang-format off +// BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize +using DeviceInstance_fp32_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp32_e16384 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; + +using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; +using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm; + +template struct emb_kernel{}; + +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e256; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e512; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e768; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e1024;}; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e1536;}; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e2048;}; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e4096;}; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp32_e8192;}; +template<> struct emb_kernel{ using kernel_type = DeviceInstance_fp32_e16384;}; + +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e256; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e512; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e768; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e1024; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e1536; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e2048; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e4096; }; +template<> struct emb_kernel { using kernel_type = DeviceInstance_fp16_e8192; }; + +// clang-format on + +int main() +{ + bool time_kernel = true; + + constexpr auto num_rows = 65536; + constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; + // constexpr auto dims = ck::Sequence<256, 512>{}; + constexpr auto index_length = 2048; + constexpr AccDataType epsilon = 1e-4; + + auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); }; + + auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) { + return HostTensorDescriptor({rows_, cols_}); + }; + + using ReferenceInstance = + ck::tensor_operation::host::ReferenceSparseEmbedding3ForwardLayernorm; + + ck::static_for<0, dims.Size(), 1>{}([&](auto I) { + std::srand(std::time(nullptr)); + constexpr auto current_dim = dims.At(I); + Tensor emb_a(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_b(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_c(f_host_tensor_desc_2d(num_rows, current_dim)); + + Tensor index_a(f_host_tensor_desc_1d(index_length)); + Tensor index_b(f_host_tensor_desc_1d(index_length)); + Tensor index_c(f_host_tensor_desc_1d(index_length)); + + Tensor gamma(f_host_tensor_desc_1d(current_dim)); + Tensor beta(f_host_tensor_desc_1d(current_dim)); + + Tensor out(f_host_tensor_desc_2d(index_length, current_dim)); + + emb_a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_c.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + index_a.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_b.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_c.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize()); + DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize()); + DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize()); + + DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize()); + DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize()); + DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize()); + + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + + DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize()); + + emb_a_dev.ToDevice(emb_a.mData.data()); + emb_b_dev.ToDevice(emb_b.mData.data()); + emb_c_dev.ToDevice(emb_c.mData.data()); + + index_a_dev.ToDevice(index_a.mData.data()); + index_b_dev.ToDevice(index_b.mData.data()); + index_c_dev.ToDevice(index_c.mData.data()); + + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = typename emb_kernel::kernel_type{}; + auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(), + emb_a_dev.GetDeviceBuffer(), + emb_b_dev.GetDeviceBuffer(), + emb_c_dev.GetDeviceBuffer(), + index_a_dev.GetDeviceBuffer(), + index_b_dev.GetDeviceBuffer(), + index_c_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + num_rows, + current_dim, + index_length, + epsilon); + std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() + << std::endl + << std::flush; + + bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get()); + + if(!is_supported) + { + std::cout << "Runtime parameters are not supported" << std::endl; + return; + } + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + float time_ms = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor out_from_dev(f_host_tensor_desc_2d(index_length, current_dim)); + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(out, + emb_a, + emb_b, + emb_c, + index_a, + index_b, + index_c, + gamma, + beta, + num_rows, + current_dim, + index_length, + epsilon); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + out_dev.FromDevice(out_from_dev.mData.data()); + pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3); + } + + double total_read = current_dim * index_length * 3 * sizeof(EmbType) + + current_dim * sizeof(GammaDataType) + + current_dim * sizeof(BetaDataType); + double total_write = current_dim * index_length * sizeof(OutType); + double gbps = (total_read + total_write) / time_ms / 1e6; + + std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms + << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl + << std::flush; + }); + + return 0; +} diff --git a/3rdparty/composable_kernel/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/3rdparty/composable_kernel/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9be3a7108fee6075dd7ae81f7536b072e655c0b --- /dev/null +++ b/3rdparty/composable_kernel/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/3rdparty/composable_kernel/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..071e8a7431c057a32ea0061655d81a202774b923 --- /dev/null +++ b/3rdparty/composable_kernel/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -0,0 +1,519 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F16; +using B0DataType = F16; +using Acc0DataType = F32; +using D00DataType = F16; +using D01DataType = F16; +using B1DataType = F16; +using Acc1DataType = F32; +using C1ShuffleDataType = F32; +using D1DataType = F16; +using E1DataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D00Layout = Row; +using D01Layout = Row; +using B1Layout = Row; +using D1Layout = Row; +using E1Layout = Row; + +// E = Relu(C + D0 + D1) +struct AddAddRelu +{ + __host__ __device__ void + operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x = c + d0 + d1; + + ck::tensor_operation::element_wise::Relu{}.template operator()(e, x); + } + __host__ __device__ void + operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x = c + (d0 + d1); + + ck::tensor_operation::element_wise::Relu{}.template operator()(e, x); + } +}; + +// E = Gelu(C + D0 + D1) +struct AddAddGelu +{ + __host__ __device__ void + operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x = c + d0 + d1; + + ck::tensor_operation::element_wise::Gelu{}.template operator()(e, + x); + } + + __host__ __device__ void + operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x = c + (d0 + d1); + + ck::tensor_operation::element_wise::Gelu{}.template operator()(e, x); + } +}; + +// E = FastGelu(C + D0 + D1) +struct AddAddFastGelu +{ + __host__ __device__ void + operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x = c + (d0 + d1); + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } +}; + +using A0ElementOp = PassThrough; +using B0ElementOp = PassThrough; +using CDE0ElementOp = AddAddRelu; +using A1ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr bool PadGemm0M = false; +static constexpr bool PadGemm0N = false; +static constexpr bool PadGemm0K = false; +static constexpr bool PadGemm1N = false; +static constexpr bool PadGemm1K = false; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< + A0Layout, + B0Layout, + ck::Tuple, + B1Layout, + ck::Tuple, + E1Layout, + A0DataType, + B0DataType, + Acc0DataType, + ck::Tuple, + B1DataType, + Acc1DataType, + C1ShuffleDataType, + ck::Tuple, + E1DataType, + A0ElementOp, + B0ElementOp, + CDE0ElementOp, + B1ElementOp, + CDE1ElementOp, + PadGemm0M, + PadGemm0N, + PadGemm0K, + PadGemm1N, + PadGemm1K, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 128; + ck::index_t BatchCount = 4; + ck::index_t StrideA0 = -1; + ck::index_t StrideB0 = -1; + ck::index_t StrideD00 = -1; + ck::index_t StrideD01 = -1; + ck::index_t StrideB1 = -1; + ck::index_t StrideD1 = -1; + ck::index_t StrideE1 = -1; + ck::index_t BatchStrideA0 = -1; + ck::index_t BatchStrideB0 = -1; + ck::index_t BatchStrideD00 = -1; + ck::index_t BatchStrideD01 = -1; + ck::index_t BatchStrideB1 = -1; + ck::index_t BatchStrideD1 = -1; + ck::index_t BatchStrideE1 = -1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + + StrideA0 = std::stoi(argv[9]); + StrideB0 = std::stoi(argv[10]); + StrideD00 = std::stoi(argv[11]); + StrideD01 = std::stoi(argv[12]); + StrideB1 = std::stoi(argv[13]); + StrideD1 = std::stoi(argv[14]); + StrideE1 = std::stoi(argv[15]); + + BatchStrideA0 = std::stoi(argv[16]); + BatchStrideB0 = std::stoi(argv[17]); + BatchStrideD00 = std::stoi(argv[18]); + BatchStrideD01 = std::stoi(argv[19]); + BatchStrideB1 = std::stoi(argv[20]); + BatchStrideD1 = std::stoi(argv[21]); + BatchStrideE1 = std::stoi(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 8: M, N, K, O, Batch\n"); + printf( + "arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n"); + printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, " + "BatchStrideB1, BatchStrideD1, BatchStrideE1 \n"); + exit(0); + } + + const int DefaultStrideA0 = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideD00 = ck::is_same_v ? N : M; + const int DefaultStrideD01 = ck::is_same_v ? N : M; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideD1 = ck::is_same_v ? O : M; + const int DefaultStrideE1 = ck::is_same_v ? O : M; + + StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00; + StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1; + StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1; + + const int DefaultBatchStrideA0 = (ck::is_same_v ? K : M) * StrideA0; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideD00 = (ck::is_same_v ? N : M) * StrideD00; + const int DefaultBatchStrideD01 = (ck::is_same_v ? N : M) * StrideD01; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideD1 = (ck::is_same_v ? O : M) * StrideD1; + const int DefaultBatchStrideE1 = (ck::is_same_v ? O : M) * StrideE1; + + BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00; + BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1; + BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + // E_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a0_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor d00_g_m_n( + f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{})); + Tensor d01_g_m_n( + f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor d1_g_m_o( + f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{})); + Tensor e1_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); + Tensor e1_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); + + std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc + << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl; + std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc + << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d00_g_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d01_g_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + break; + case 2: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d00_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d01_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + default: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + d00_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + d01_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); + DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); + DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) * + e1_g_m_o_device_result.mDesc.GetElementSize()); + DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize()); + + a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data()); + d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data()); + + auto a0_element_op = A0ElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto cde0_element_op = CDE0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto cde1_element_op = CDE1ElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(static_cast(a0_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + std::array{d00_g_m_n_device_buf.GetDeviceBuffer(), + d01_g_m_n_device_buf.GetDeviceBuffer()}, + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + std::array{d1_g_m_o_device_buf.GetDeviceBuffer()}, + static_cast(e1_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA0, + StrideB0, + std::array{StrideD00, StrideD01}, + StrideB1, + std::array{StrideD1}, + StrideE1, + BatchStrideA0, + BatchStrideB0, + std::array{BatchStrideD00, BatchStrideD01}, + BatchStrideB1, + std::array{BatchStrideD1}, + BatchStrideE1, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = + (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N + + sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + + sizeof(D1DataType) * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data()); + + if(do_verification) + { + using ReferenceGemm0Instance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + using ReferenceGemm1Instance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + // Output of Gemm0 is input A of Gemm1 + Tensor c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{})); + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // bias+bias+relu + e0_g_m_n.ForEach([&](auto&, auto idx) { + cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx)); + }); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{}); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // bias + e1_g_m_o_host_result.ForEach([&](auto&, auto idx) { + cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx)); + }); + + return ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9cf960c501ccb279febc8205505ca0edfdeefac4 --- /dev/null +++ b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -0,0 +1,7 @@ +add_custom_target(example_grouped_conv_bwd_data) + +add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) +add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) + +add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) +add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) diff --git a/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d07ee7bdc1c4b999f61a84727bb10fe7ea6f4c15 --- /dev/null +++ b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static inline constexpr ck::index_t NDimSpatial = 2; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +using FP16 = ck::half_t; +using FP32 = float; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +#define DefaultConvParams \ + ck::utils::conv::ConvParam \ + { \ + NDimSpatial, 32, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55ea8c3a3109469532d0e5c0566da3eb11e7353d --- /dev/null +++ b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using BiasDataType = FP16; // bias +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::Tuple; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = ck::tensor_operation::element_wise::AddRelu; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_conv_bwd_data_bias_relu_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_bias_relu_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ddf82ec512c61c72a1a2e2323a6645f39f15ce6d --- /dev/null +++ b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..0afd8bd70da849085db6e19583e933c3d1b91968 --- /dev/null +++ b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_params, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& bias_g_n_c_wis_desc, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const OutElementOp& out_element_op, + const WeiElementOp& wei_element_op, + const InElementOp& in_element_op) +{ + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_n_c_wis_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_c_wis_lengths{}; + std::array d0_g_n_c_wis_strides{}; + std::array e_g_n_c_wis_lengths{}; + std::array e_g_n_c_wis_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_n_c_wis_desc.GetLengths(), d0_g_n_c_wis_lengths); + copy(bias_g_n_c_wis_desc.GetStrides(), d0_g_n_c_wis_strides); + copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + static_assert(std::is_default_constructible_v); + + // do conv + auto conv = DeviceConvInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer()}, + in_device_buf.GetDeviceBuffer(), + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 1>{d0_g_n_c_wis_lengths}, + std::array, 1>{d0_g_n_c_wis_strides}, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_params.GetFlops(); + std::size_t num_btype = conv_params.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(config.do_verification) + { + // c doesn't physically exist, any layout is fine + Tensor c_host(in_g_n_c_wis_desc); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(c_host, + wei, + out, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_, + PassThrough{}, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + in_host.ForEach( + [&](auto&, auto idx) { in_element_op(in_host(idx), c_host(idx), bias(idx)); }); + + in_device_buf.FromDevice(in_device.mData.data()); + + return ck::utils::check_err(in_device, in_host); + } + + return true; +} + +int run_grouped_conv_bwd_data_bias_relu_example(int argc, char* argv[]) +{ + namespace ctc = ck::tensor_layout::convolution; + + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatials dimensions" << std::endl; + return EXIT_FAILURE; + } + + // output image: GNHWK + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_params); + + // weight: GKYXC + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_params); + + // input image bias: G_C + const auto bias_g_n_c_wis_desc = HostTensorDescriptor({conv_params.G_, + conv_params.N_, + conv_params.C_, + conv_params.input_spatial_lengths_[0], + conv_params.input_spatial_lengths_[1]}, + { + conv_params.C_, // g + 0, // n + 1, // c + 0, // hi + 0 // wi + }); + + // input image: GNHWC + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + + return !run_conv_bwd_data_bias_relu(config, + conv_params, + out_g_n_k_wos_desc, + wei_g_k_c_xs_desc, + bias_g_n_c_wis_desc, + in_g_n_c_wis_desc, + wei_element_op, + out_element_op, + in_element_op); +} diff --git a/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..e50c98bbe844f6ec438c781b7d5d8e254e57d98f --- /dev/null +++ b/3rdparty/composable_kernel/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +bool run_conv_bwd_data(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_params, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const OutElementOp& out_element_op, + const WeiElementOp& wei_element_op, + const InElementOp& in_element_op) +{ + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_c_wis_lengths{}; + std::array e_g_n_c_wis_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + static_assert(std::is_default_constructible_v); + + // do conv + auto conv = DeviceConvInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + in_device_buf.GetDeviceBuffer(), + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_params.GetFlops(); + std::size_t num_btype = conv_params.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(config.do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_, + PassThrough{}, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_device.mData.data()); + + return ck::utils::check_err(in_device.mData, in_host.mData); + } + + return true; +} + +int run_grouped_conv_bwd_data_example(int argc, char* argv[]) +{ + namespace ctc = ck::tensor_layout::convolution; + + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatials dimensions" << std::endl; + return EXIT_FAILURE; + } + + // output image: GNHWK + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_params); + + // weight: GKYXC + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_params); + + // input image: GNHWC + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + + return !run_conv_bwd_data(config, + conv_params, + out_g_n_k_wos_desc, + wei_g_k_c_xs_desc, + in_g_n_c_wis_desc, + wei_element_op, + out_element_op, + in_element_op); +} diff --git a/3rdparty/composable_kernel/example/39_permute/CMakeLists.txt b/3rdparty/composable_kernel/example/39_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..573ad7239e608e7ba947f9ca864a0c7f4da4689b --- /dev/null +++ b/3rdparty/composable_kernel/example/39_permute/CMakeLists.txt @@ -0,0 +1,9 @@ +add_custom_target(example_permute) + +add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) +add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) +add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) + +add_dependencies(example_permute example_permute_1xHxW_fp16) +add_dependencies(example_permute example_permute_NxHxW_fp16) +add_dependencies(example_permute example_permute_HxWx4_fp16) diff --git a/3rdparty/composable_kernel/example/39_permute/common.hpp b/3rdparty/composable_kernel/example/39_permute/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab612cea1794c422c737a02e5bbe7a6728904c40 --- /dev/null +++ b/3rdparty/composable_kernel/example/39_permute/common.hpp @@ -0,0 +1,456 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/utility/type.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; + +struct Problem final +{ + static constexpr std::size_t NumDim = 3; + + using Shape = std::array; + using Axes = Shape; + + Problem() = delete; + + explicit Problem(const Shape& default_shape, const Axes& default_axes) + : shape(default_shape), axes(default_axes) + { + } + + Shape shape; + Axes axes; +}; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +namespace detail { + +template +struct enlarge_array_size; + +template +struct enlarge_array_size, Difference> +{ + using type = std::array; +}; + +template +using enlarge_array_size_t = typename enlarge_array_size::type; + +template +struct get_array_size; + +template +struct get_array_size> : std::integral_constant +{ +}; + +template +inline constexpr std::size_t get_array_size_v = get_array_size::value; + +template +struct is_iterator : std::false_type +{ +}; + +template +struct is_iterator()), + decltype(++std::declval>()), + decltype(std::declval>()++)>> + : std::true_type +{ +}; + +template +inline constexpr bool is_iterator_v = is_iterator::value; + +struct Placeholder final +{ + template + constexpr inline operator T() const noexcept; +}; + +template +struct is_output_iterator : std::false_type +{ +}; + +template +struct is_output_iterator< + Iterator, + std::void_t() = std::declval())>> + : std::bool_constant> +{ +}; + +template +inline constexpr bool is_output_iterator_v = is_output_iterator::value; + +template +struct is_bidirectional_iterator : std::false_type +{ +}; + +template +struct is_bidirectional_iterator< + Iterator, + std::void_t>()), + decltype(std::declval>()--)>> + : std::bool_constant> +{ +}; + +template +inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator::value; + +template +struct is_random_access_iterator : std::false_type +{ +}; + +template +struct is_random_access_iterator() + 1), + decltype(std::declval() - 1), + decltype(std::declval()[1])>> + : std::bool_constant> +{ +}; + +template +inline constexpr bool is_random_access_iterator_v = is_random_access_iterator::value; + +template +struct is_range : std::false_type +{ +}; + +template +struct is_range())), + decltype(end(std::declval())), + decltype(begin(std::declval()) != end(std::declval()))>> + : std::bool_constant()))>>> +{ +}; + +template +inline constexpr bool is_range_v = is_range::value; + +template +struct is_sized_range : std::false_type +{ +}; + +template +struct is_sized_range()))>> + : std::bool_constant> +{ +}; + +template +inline constexpr bool is_sized_range_v = is_sized_range::value; + +template +struct is_bidirectional_range : std::false_type +{ +}; + +template +struct is_bidirectional_range> + : std::bool_constant< + is_range_v && + is_bidirectional_iterator_v()))>>> +{ +}; + +template +inline constexpr bool is_bidirectional_range_v = is_bidirectional_range::value; + +template +struct is_random_access_range : std::false_type +{ +}; + +template +struct is_random_access_range> + : std::bool_constant< + is_range_v && + is_random_access_iterator_v()))>>> +{ +}; + +template +inline constexpr bool is_random_access_range_v = is_random_access_range::value; + +template +class to_array_proxy +{ + static_assert(is_range_v); + + public: + explicit to_array_proxy(const Range& source) noexcept : source_(source) {} + + template + operator std::array() const + { + std::array destination; + + std::copy_n(std::begin(source_), + std::min(Size, std::size(source_)), + std::begin(destination)); + + return destination; + } + + private: + const Range& source_; +}; + +} // namespace detail + +template +inline auto to_array(Range& range) noexcept + -> std::enable_if_t, + detail::to_array_proxy>> +{ + return detail::to_array_proxy>{range}; +} + +template +inline auto is_valid_axes(const Axes& axes) + -> std::enable_if_t, bool> +{ + using std::empty; + if(empty(axes)) + { + return false; + } + + using std::begin, std::end; + std::vector sorted_axes(begin(axes), end(axes)); + + std::sort(begin(sorted_axes), end(sorted_axes)); + const auto last = std::unique(begin(sorted_axes), end(sorted_axes)); + + return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) && + (*std::prev(last) == size(axes) - 1); +} + +template +inline auto is_valid_shape(const Shape& shape) -> std::enable_if_t, bool> +{ + static_assert(std::is_unsigned_v>); + + using std::begin, std::end; + using std::empty; + return !empty(shape) && std::all_of(begin(shape), end(shape), [](auto dim) { return 0 < dim; }); +} + +template +inline auto is_valid_indices(const Shape& shape, const Indices& indices) + -> std::enable_if_t && detail::is_sized_range_v, bool> +{ + static_assert(std::is_unsigned_v>); + + if(!is_valid_shape(shape)) + { + return false; + } + + using std::empty; + if(empty(indices)) + { + return false; + } + + using std::size; + if(size(shape) != size(indices)) + { + return false; + } + + using std::begin, std::end; + + auto dim = begin(shape); + auto idx = begin(indices); + for(; dim != end(shape) && idx != end(indices); ++dim, ++idx) + { + if(*dim <= *idx) + { + return false; + } + } + + return true; +} + +template +std::array transpose(const std::array& shape, + const std::array& axes) +{ + assert(is_valid_shape(shape) && is_valid_axes(axes)); + + std::array transposed; + auto iter = std::begin(transposed); + for(const auto axis : axes) + { + *iter++ = shape[axis]; + } + + return transposed; +} + +auto extend_shape(const Problem::Shape& shape, std::size_t new_dim) +{ + detail::enlarge_array_size_t extended_shape; + + using std::begin, std::end; + + ck::ranges::copy(shape, begin(extended_shape)); + extended_shape.back() = new_dim; + + return extended_shape; +} + +auto extend_axes(const Problem::Axes& axes) +{ + detail::enlarge_array_size_t extended_axes; + + using std::begin, std::end; + + ck::ranges::copy(axes, begin(extended_axes)); + extended_axes.back() = detail::get_array_size_v; + + return extended_axes; +} + +template +auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t< + detail::is_bidirectional_range_v && detail::is_sized_range_v && + detail::is_bidirectional_range_v && detail::is_sized_range_v, + bool> +{ + using std::size; + if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) + { + return false; + } + + bool carry = true; + + using std::rbegin, std::rend; + auto dim = rbegin(shape); + auto idx = rbegin(indices); + for(; carry && dim != rend(shape) && idx != rend(indices); ++dim, ++idx) + { + *idx = (*idx + carry); + carry = ((*idx == *dim) ? (*idx = 0, true) : false); + } + + return !carry; +} + +template +auto host_permute(const Tensor& src, const Axes& axes, Functor functor, Tensor& dest) + -> std::enable_if_t && detail::is_sized_range_v && + std::is_invocable_v, + std::add_lvalue_reference_t>, + bool> +{ + const auto& shape = src.mDesc.GetLengths(); + const auto& transposed_shape = dest.mDesc.GetLengths(); + if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape))) + { + return false; + } + + using std::size; + if(!is_valid_axes(axes)) + { + return false; + } + + static_assert(detail::is_sized_range_v> && + detail::is_sized_range_v>); + + if(size(shape) != size(transposed_shape)) + { + return false; + } + + static_assert(detail::is_random_access_range_v> && + detail::is_random_access_range_v>); + { + for(std::size_t idx = 0; idx < size(shape); ++idx) + { + if(transposed_shape[idx] != shape[axes[idx]]) + { + return false; + } + } + } + + std::vector indices(size(shape), 0); + if(!is_valid_indices(shape, indices)) + { + return false; + } + + switch(size(shape)) + { + case 3: { + do + { + Dest output = 0; + functor(output, src(indices[0], indices[1], indices[2])); + dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output; + } while(advance_indices(shape, indices)); + } + break; + case 4: { + do + { + Dest output = 0; + functor(output, src(indices[0], indices[1], indices[2], indices[3])); + dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output; + } while(advance_indices(shape, indices)); + } + break; + default: return false; + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/39_permute/permute_1xHxW_fp16.cpp b/3rdparty/composable_kernel/example/39_permute/permute_1xHxW_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7f9b80544a452ca7ea062d7d6135b92a2bf19ea --- /dev/null +++ b/3rdparty/composable_kernel/example/39_permute/permute_1xHxW_fp16.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = F16; +using OutDataType = F16; + +// clang-format off +using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl +// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst| +// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | +// ######| | | | | | | | | | | | | | | | + < 3, InDataType, OutDataType, PassThrough, 256, 1, 32, 32, 3, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 2, 1>; +// clang-format on + +#include "run_permute_element_example.inc" + +int main() { return !run_permute_element_example({1, 32000, 80}, {0, 2, 1}); } diff --git a/3rdparty/composable_kernel/example/39_permute/permute_HxWx4_fp16.cpp b/3rdparty/composable_kernel/example/39_permute/permute_HxWx4_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..342aa134ec5570c84447b137e5be165c9a5f697c --- /dev/null +++ b/3rdparty/composable_kernel/example/39_permute/permute_HxWx4_fp16.cpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using DataType = F16; +using BundleType = F64; + +static_assert(sizeof(BundleType) % sizeof(DataType) == 0); + +// clang-format off +using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl +// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst| +// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | +// ######| | | | | | | | | | | | | | | | + < 3, BundleType, BundleType, PassThrough, 256, 1, 32, 32, 5, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 4, 1>; +// clang-format on + +#include "run_permute_bundle_example.inc" + +int main() { return !run_permute_bundle_example({1, 80, 32000}, {0, 2, 1}); } diff --git a/3rdparty/composable_kernel/example/39_permute/permute_NxHxW_fp16.cpp b/3rdparty/composable_kernel/example/39_permute/permute_NxHxW_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b53975eb2c8632583f61afde50a03eba32c17c9b --- /dev/null +++ b/3rdparty/composable_kernel/example/39_permute/permute_NxHxW_fp16.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = F16; +using OutDataType = F16; + +// clang-format off +using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl +// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst| +// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | +// ######| | | | | | | | | | | | | | | | + < 3, InDataType, OutDataType, PassThrough, 128, 4, 16, 8, 6, S<2, 16, 4>, S<0, 1, 2>, 2, 1, 2, 1>; +// clang-format on + +#include "run_permute_element_example.inc" + +int main() { return !run_permute_element_example({121, 768, 80}, {0, 2, 1}); } diff --git a/3rdparty/composable_kernel/example/39_permute/run_permute_bundle_example.inc b/3rdparty/composable_kernel/example/39_permute/run_permute_bundle_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..70406d63f91b95f4d1cce025051a74a0ee3d114e --- /dev/null +++ b/3rdparty/composable_kernel/example/39_permute/run_permute_bundle_example.inc @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_permute_bundle(const Problem& problem) +{ + const auto& input_bundle_shape = problem.shape; + const auto& input_bundle_axes = problem.axes; + + const auto output_bundle_shape = transpose(input_bundle_shape, input_bundle_axes); + + Tensor input_bundle_tensor(input_bundle_shape); + Tensor output_bundle_tensor(output_bundle_shape); + + // initialize tensor by assigning DataType values + ck::utils::FillUniformDistribution{-1.f, 1.f}(input_bundle_tensor.AsSpan()); + + DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes()); + DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes()); + + using std::data; + input_device_buf.ToDevice(data(input_bundle_tensor)); + + static_assert(std::is_default_constructible_v); + + auto permute = DevicePermuteInstance{}; + auto argument = permute.MakeArgument(to_array(input_bundle_shape), + to_array(input_bundle_tensor.GetStrides()), + to_array(output_bundle_shape), + to_array(output_bundle_tensor.GetStrides()), + input_device_buf.GetDeviceBuffer(), + output_device_buf.GetDeviceBuffer(), + PassThrough{}); + + if(!permute.IsSupportedArgument(argument)) + { + std::cerr << "The runtime parameters seems not supported by the device instance, exiting!" + << std::endl; + return false; + }; + + auto invoker = permute.MakeInvoker(); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + output_device_buf.FromDevice(data(output_bundle_tensor)); + + constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType); + + // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle] + // axes from [0, 2, 1] to [0, 2, 1, 3] + const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle); + const auto input_axes = extend_axes(input_bundle_axes); + + using std::begin; + + Tensor input_tensor(input_shape); + ck::ranges::copy(input_bundle_tensor.AsSpan(), begin(input_tensor)); + + Tensor output_tensor(transpose(input_shape, input_axes)); + if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor)) + { + return false; + } + + return ck::utils::check_err(output_bundle_tensor.AsSpan(), + output_tensor.AsSpan(), + "Error: incorrect results in output tensor", + 1e-6, + 1e-6); +} + +bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes) +{ + return run_permute_bundle(Problem{shape, axes}); +} diff --git a/3rdparty/composable_kernel/example/39_permute/run_permute_element_example.inc b/3rdparty/composable_kernel/example/39_permute/run_permute_element_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..bc6235303039c7cb33a4057446137a9f18ddb186 --- /dev/null +++ b/3rdparty/composable_kernel/example/39_permute/run_permute_element_example.inc @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_permute_element(const Problem& problem) +{ + const auto& input_shape = problem.shape; + const auto& input_axes = problem.axes; + + const auto output_shape = transpose(input_shape, input_axes); + + Tensor input_tensor(input_shape); + Tensor output_tensor(output_shape); + + ck::utils::FillUniformDistribution{-1.f, 1.f}(input_tensor); + + DeviceMem input_device_buf(input_tensor.GetElementSpaceSizeInBytes()); + DeviceMem output_device_buf(output_tensor.GetElementSpaceSizeInBytes()); + + using std::data; + input_device_buf.ToDevice(data(input_tensor)); + + static_assert(std::is_default_constructible_v); + + auto permute = DevicePermuteInstance{}; + auto argument = permute.MakeArgument(to_array(input_shape), + to_array(input_tensor.GetStrides()), + to_array(output_shape), + to_array(output_tensor.GetStrides()), + input_device_buf.GetDeviceBuffer(), + output_device_buf.GetDeviceBuffer(), + PassThrough{}); + + if(!permute.IsSupportedArgument(argument)) + { + std::cerr << "The runtime parameters seems not supported by the device instance, exiting!" + << std::endl; + return false; + }; + + auto invoker = permute.MakeInvoker(); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + output_device_buf.FromDevice(data(output_tensor)); + + Tensor output_tensor_host(output_shape); + if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor_host)) + { + return false; + } + + return ck::utils::check_err(output_tensor.AsSpan(), + output_tensor_host.AsSpan(), + "Error: incorrect results in output tensor", + 1e-6, + 1e-6); +} + +bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes) +{ + return run_permute_element(Problem{shape, axes}); +} diff --git a/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9cb30f61760a47a2f72c98e31de2b522b97543ca --- /dev/null +++ b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -0,0 +1,8 @@ +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) + +if(USE_BITINT_EXTENSION_INT4) +add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2aea08c40004fac22ea0e45521ad09fd806751d8 --- /dev/null +++ b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using In0DataType = ck::bhalf_t; +using Wei0DataType = ck::bhalf_t; +using Acc0DataType = float; +using Wei1DataType = ck::bhalf_t; +using Acc1DataType = float; +using C1ShuffleDataType = float; +using Out1DataType = ck::bhalf_t; + +// This is used for reference code +using Out0DataType = ck::bhalf_t; + +template +using S = ck::Sequence; + +using In0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceBatchedGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + Row, // ALayout + Col, // B0Layout + Col, // B1Layout + Row, // CLayout + In0DataType, // ADataType, + Wei0DataType, // B0DataType, + Wei1DataType, // B1DataType, + Out1DataType, // CDataType, + Acc0DataType, // AccDataType, + C1ShuffleDataType, // CShuffleDataType, + In0ElementOp, // AElementOp, + Wei0ElementOp, // B0ElementOp, + Out0ElementOp, // Acc0ElementOp, + Wei1ElementOp, // B1ElementOp, + Out1ElementOp, // CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 4, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // B1BlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +#include "run_grouped_conv_conv_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b7f80e76d6ce5a690beef4f48361e0a65402cb47 --- /dev/null +++ b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using In0DataType = ck::half_t; +using Wei0DataType = ck::half_t; +using Acc0DataType = float; +using Wei1DataType = ck::half_t; +using Acc1DataType = float; +using C1ShuffleDataType = float; +using Out1DataType = ck::half_t; + +// This is used for reference code +using Out0DataType = ck::half_t; + +template +using S = ck::Sequence; + +using In0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceBatchedGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + Row, // ALayout + Col, // B0Layout + Col, // B1Layout + Row, // CLayout + In0DataType, // ADataType, + Wei0DataType, // B0DataType, + Wei1DataType, // B1DataType, + Out1DataType, // CDataType, + Acc0DataType, // AccDataType, + C1ShuffleDataType, // CShuffleDataType, + In0ElementOp, // AElementOp, + Wei0ElementOp, // B0ElementOp, + Out0ElementOp, // Acc0ElementOp, + Wei1ElementOp, // B1ElementOp, + Out1ElementOp, // CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 4, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // B1BlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +#include "run_grouped_conv_conv_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15e460948ef528b697a4e3bad7d40ca332f85c15 --- /dev/null +++ b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using In0DataType = float; +using Wei0DataType = float; +using Acc0DataType = float; +using Wei1DataType = float; +using Acc1DataType = float; +using C1ShuffleDataType = float; +using Out1DataType = float; + +// This is used for reference code +using Out0DataType = float; + +template +using S = ck::Sequence; + +using In0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceBatchedGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + Row, // ALayout + Col, // B0Layout + Col, // B1Layout + Row, // CLayout + In0DataType, // ADataType, + Wei0DataType, // B0DataType, + Wei1DataType, // B1DataType, + Out1DataType, // CDataType, + Acc0DataType, // AccDataType, + C1ShuffleDataType, // CShuffleDataType, + In0ElementOp, // AElementOp, + Wei0ElementOp, // B0ElementOp, + Out0ElementOp, // Acc0ElementOp, + Wei1ElementOp, // B1ElementOp, + Out1ElementOp, // CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 128, // Gemm1NPerBlock + 16, // Gemm1KPerBlock + 4, // AK1 + 4, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + S<4, 64, 1>, // B1BlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +#include "run_grouped_conv_conv_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2cc4c07c0d89fc83b233b50967cd595dc0e53ac7 --- /dev/null +++ b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#error Should compile this file with ck::int4_t support +#endif + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using In0DataType = ck::int4_t; +using Wei0DataType = ck::int4_t; +using KernelIn0DataType = int8_t; +using KernelWei0DataType = int8_t; +using Acc0DataType = int32_t; +using Wei1DataType = ck::int4_t; +using KernelWei1DataType = int8_t; +using Acc1DataType = int32_t; +using C1ShuffleDataType = int32_t; +using Out1DataType = ck::int4_t; +using KernelOut1DataType = int8_t; + +// This is used for reference code +using Out0DataType = ck::int4_t; + +template +using S = ck::Sequence; + +using In0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceBatchedGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + Row, // ALayout + Col, // B0Layout + Col, // B1Layout + Row, // CLayout + KernelIn0DataType, // ADataType, + KernelWei0DataType, // B0DataType, + KernelWei1DataType, // B1DataType, + KernelOut1DataType, // CDataType, + Acc0DataType, // AccDataType, + C1ShuffleDataType, // CShuffleDataType, + In0ElementOp, // AElementOp, + Wei0ElementOp, // B0ElementOp, + Out0ElementOp, // Acc0ElementOp, + Wei1ElementOp, // B1ElementOp, + Out1ElementOp, // CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 128, // Gemm1NPerBlock + 64, // Gemm1KPerBlock + 16, // AK1 + 16, // BK1 + 4, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<4, 64, 1>, // B1BlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +#define BUILD_INT4_EXAMPLE +#include "run_grouped_conv_conv_fwd_example.inc" + +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) +static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + +int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40ff0f69cc14cd644361f3391eea8ed52380adc0 --- /dev/null +++ b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using In0DataType = int8_t; +using Wei0DataType = int8_t; +using Acc0DataType = int32_t; +using Wei1DataType = int8_t; +using Acc1DataType = int32_t; +using C1ShuffleDataType = int32_t; +using Out1DataType = int8_t; + +// This is used for reference code +using Out0DataType = int8_t; + +template +using S = ck::Sequence; + +using In0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceBatchedGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + Row, // ALayout + Col, // B0Layout + Col, // B1Layout + Row, // CLayout + In0DataType, // ADataType, + Wei0DataType, // B0DataType, + Wei1DataType, // B1DataType, + Out1DataType, // CDataType, + Acc0DataType, // AccDataType, + C1ShuffleDataType, // CShuffleDataType, + In0ElementOp, // AElementOp, + Wei0ElementOp, // B0ElementOp, + Out0ElementOp, // Acc0ElementOp, + Wei1ElementOp, // B1ElementOp, + Out1ElementOp, // CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 128, // Gemm1NPerBlock + 64, // Gemm1KPerBlock + 16, // AK1 + 16, // BK1 + 4, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + true, + S<4, 64, 1>, // B1BlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +#include "run_grouped_conv_conv_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } diff --git a/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..a2c97f4d421f9382aa279b257a49299c4016e2c3 --- /dev/null +++ b/3rdparty/composable_kernel/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_grouped_conv_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv0_param, + const ck::utils::conv::ConvParam& conv1_param, + const HostTensorDescriptor& in0_g_n_c_wis_desc, + const HostTensorDescriptor& wei0_g_k_c_xs_desc, + const HostTensorDescriptor& out0_g_n_k_wos_desc, + const HostTensorDescriptor& wei1_g_k_c_xs_desc, + const HostTensorDescriptor& out1_g_n_k_wos_desc, + const In0ElementOp& in0_element_op, + const Wei0ElementOp& wei0_element_op, + const Wei1ElementOp& wei1_element_op, + const Out0ElementOp& out0_element_op, + const Out1ElementOp& out1_element_op) +{ + Tensor in0(in0_g_n_c_wis_desc); + Tensor wei0(wei0_g_k_c_xs_desc); + Tensor wei1(wei1_g_k_c_xs_desc); + Tensor out1_host(out1_g_n_k_wos_desc); + Tensor out1_device(out1_g_n_k_wos_desc); + + std::cout << "in0: " << in0.mDesc << std::endl; + std::cout << "wei0: " << wei0.mDesc << std::endl; + std::cout << "wei1: " << wei1.mDesc << std::endl; + std::cout << "out1: " << out1_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in0.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei0.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei1.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei0.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + wei1.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem in0_device_buf(sizeof(KernelIn0DataType) * in0.mDesc.GetElementSpaceSize()); + DeviceMem wei0_device_buf(sizeof(KernelWei0DataType) * wei0.mDesc.GetElementSpaceSize()); + DeviceMem wei1_device_buf(sizeof(KernelWei1DataType) * wei1.mDesc.GetElementSpaceSize()); + DeviceMem out1_device_buf(sizeof(KernelOut1DataType) * out1_device.mDesc.GetElementSpaceSize()); + + const Tensor in0_converted(in0); + const Tensor wei0_converted(wei0); + const Tensor wei1_converted(wei1); + + in0_device_buf.ToDevice(in0_converted.mData.data()); + wei0_device_buf.ToDevice(wei0_converted.mData.data()); + wei1_device_buf.ToDevice(wei1_converted.mData.data()); +#else + DeviceMem in0_device_buf(sizeof(In0DataType) * in0.mDesc.GetElementSpaceSize()); + DeviceMem wei0_device_buf(sizeof(Wei0DataType) * wei0.mDesc.GetElementSpaceSize()); + DeviceMem wei1_device_buf(sizeof(Wei1DataType) * wei1.mDesc.GetElementSpaceSize()); + DeviceMem out1_device_buf(sizeof(Out1DataType) * out1_device.mDesc.GetElementSpaceSize()); + + in0_device_buf.ToDevice(in0.mData.data()); + wei0_device_buf.ToDevice(wei0.mData.data()); + wei1_device_buf.ToDevice(wei1.mData.data()); +#endif + + std::array a0_g_n_c_wis_lengths{}; + std::array a0_g_n_c_wis_strides{}; + std::array b0_g_k_c_xs_lengths{}; + std::array b0_g_k_c_xs_strides{}; + std::array b1_g_k_c_xs_lengths{}; + std::array b1_g_k_c_xs_strides{}; + std::array e1_g_n_k_wos_lengths{}; + std::array e1_g_n_k_wos_strides{}; + std::array conv0_filter_strides{}; + std::array conv0_filter_dilations{}; + std::array input0_left_pads{}; + std::array input0_right_pads{}; + std::array conv1_filter_strides{}; + std::array conv1_filter_dilations{}; + std::array input1_left_pads{}; + std::array input1_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in0_g_n_c_wis_desc.GetLengths(), a0_g_n_c_wis_lengths); + copy(in0_g_n_c_wis_desc.GetStrides(), a0_g_n_c_wis_strides); + copy(wei0_g_k_c_xs_desc.GetLengths(), b0_g_k_c_xs_lengths); + copy(wei0_g_k_c_xs_desc.GetStrides(), b0_g_k_c_xs_strides); + copy(wei1_g_k_c_xs_desc.GetLengths(), b1_g_k_c_xs_lengths); + copy(wei1_g_k_c_xs_desc.GetStrides(), b1_g_k_c_xs_strides); + copy(out1_g_n_k_wos_desc.GetLengths(), e1_g_n_k_wos_lengths); + copy(out1_g_n_k_wos_desc.GetStrides(), e1_g_n_k_wos_strides); + copy(conv0_param.conv_filter_strides_, conv0_filter_strides); + copy(conv0_param.conv_filter_dilations_, conv0_filter_dilations); + copy(conv0_param.input_left_pads_, input0_left_pads); + copy(conv0_param.input_right_pads_, input0_right_pads); + copy(conv1_param.conv_filter_strides_, conv1_filter_strides); + copy(conv1_param.conv_filter_dilations_, conv1_filter_dilations); + copy(conv1_param.input_left_pads_, input1_left_pads); + copy(conv1_param.input_right_pads_, input1_right_pads); + + // do Conv using GEMM, only works for 1x1 conv for now + const ck::index_t gemm_batch = a0_g_n_c_wis_lengths[0]; + + const ck::index_t gemm0_m_length = + e1_g_n_k_wos_lengths[1] * + ck::accumulate_n( + e1_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>{}); + + const ck::index_t gemm0_n_length = b0_g_k_c_xs_lengths[1]; + + const ck::index_t gemm0_k_length = ck::accumulate_n( + b0_g_k_c_xs_lengths.begin() + 2, NDimSpatial + 1, 1, std::multiplies<>{}); + + const ck::index_t gemm1_n_length = b1_g_k_c_xs_lengths[1]; + + // + const ck::index_t a0_stride = a0_g_n_c_wis_strides[2 + NDimSpatial]; + const ck::index_t b0_stride = b0_g_k_c_xs_strides[2 + NDimSpatial]; + const ck::index_t b1_stride = b1_g_k_c_xs_strides[2 + NDimSpatial]; + const ck::index_t e1_stride = e1_g_n_k_wos_strides[2 + NDimSpatial]; + + // + const ck::index_t a0_batch_stride = a0_g_n_c_wis_strides[0]; + const ck::index_t b0_batch_stride = b0_g_k_c_xs_strides[0]; + const ck::index_t b1_batch_stride = b1_g_k_c_xs_strides[0]; + const ck::index_t e1_batch_stride = e1_g_n_k_wos_strides[0]; + + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(in0_device_buf.GetDeviceBuffer()), + static_cast(wei0_device_buf.GetDeviceBuffer()), + static_cast(wei1_device_buf.GetDeviceBuffer()), + static_cast(out1_device_buf.GetDeviceBuffer()), +#else + static_cast(in0_device_buf.GetDeviceBuffer()), + static_cast(wei0_device_buf.GetDeviceBuffer()), + static_cast(wei1_device_buf.GetDeviceBuffer()), + static_cast(out1_device_buf.GetDeviceBuffer()), +#endif + gemm0_m_length, + gemm0_n_length, + gemm0_k_length, + gemm1_n_length, + gemm_batch, + a0_stride, + b0_stride, + b1_stride, + e1_stride, + a0_batch_stride, + b0_batch_stride, + b1_batch_stride, + e1_batch_stride, + in0_element_op, + wei0_element_op, + out0_element_op, + wei1_element_op, + out1_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv0_param.GetFlops() + conv1_param.GetFlops(); + std::size_t num_btype = conv0_param.template GetInputByte() + + conv0_param.template GetWeightByte() + + conv1_param.template GetWeightByte() + + conv1_param.template GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; + + if(do_verification) + { + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + Tensor out0_host(out0_g_n_k_wos_desc); + + auto ref_conv0 = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_conv1 = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_conv0_invoker = ref_conv0.MakeInvoker(); + auto ref_conv1_invoker = ref_conv1.MakeInvoker(); + + auto ref_conv0_argument = ref_conv0.MakeArgument(in0, + wei0, + out0_host, + conv0_param.conv_filter_strides_, + conv0_param.conv_filter_dilations_, + conv0_param.input_left_pads_, + conv0_param.input_right_pads_, + in0_element_op, + wei0_element_op, + out0_element_op); + + auto ref_conv1_argument = ref_conv1.MakeArgument(out0_host, + wei1, + out1_host, + conv1_param.conv_filter_strides_, + conv1_param.conv_filter_dilations_, + conv1_param.input_left_pads_, + conv1_param.input_right_pads_, + out0_element_op, + wei1_element_op, + out1_element_op); + + ref_conv0_invoker.Run(ref_conv0_argument); + ref_conv1_invoker.Run(ref_conv1_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor out1_device_converted(out1_host.mDesc); + + out1_device_buf.FromDevice(out1_device_converted.mData.data()); + + out1_device = out1_device_converted.CopyAsType(); +#else + out1_device_buf.FromDevice(out1_device.mData.data()); +#endif + + return ck::utils::check_err( + out1_device, out1_host, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return true; +} + +bool run_grouped_conv_conv_fwd_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv0_param{ + 2, 1, 128, 512, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + ck::utils::conv::ConvParam conv1_param{ + 2, 1, 128, 128, 512, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + const auto in0_element_op = In0ElementOp{}; + const auto wei0_element_op = Wei0ElementOp{}; + const auto wei1_element_op = Wei1ElementOp{}; + const auto out0_element_op = Out0ElementOp{}; + const auto out1_element_op = Out1ElementOp{}; + + const auto run = [&](auto ndim_spatial, + auto in0_layout, + auto wei0_layout, + auto wei1_layout, + auto out1_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using In0Layout = decltype(in0_layout); + using Wei0Layout = decltype(wei0_layout); + using Wei1Layout = decltype(wei1_layout); + using Out1Layout = decltype(out1_layout); + + const auto in0_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv0_param); + + const auto wei0_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv0_param); + + // out0 doesn't physical exist, any layout for host verification is OK + const auto out0_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv0_param); + + const auto wei1_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv1_param); + + const auto out1_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv1_param); + + return run_grouped_conv_conv_fwd(do_verification, + init_method, + time_kernel, + conv0_param, + conv1_param, + in0_g_n_c_wis_desc, + wei0_g_k_c_xs_desc, + out0_g_n_k_wos_desc, + wei1_g_k_c_xs_desc, + out1_g_n_k_wos_desc, + in0_element_op, + wei0_element_op, + wei1_element_op, + out0_element_op, + out1_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv0_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GKXC{}, ctc::GNWK{}); + } + else if(conv0_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GKYXC{}, ctc::GNHWK{}); + } + else if(conv0_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GKZYXC{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/3rdparty/composable_kernel/example/42_groupnorm/CMakeLists.txt b/3rdparty/composable_kernel/example/42_groupnorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c3b7b825920314aba3fa15e401da26860c23a3c3 --- /dev/null +++ b/3rdparty/composable_kernel/example/42_groupnorm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_groupnorm_sigmoid_fp16 groupnorm_sigmoid_fp16.cpp) diff --git a/3rdparty/composable_kernel/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp b/3rdparty/composable_kernel/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e62001d6692d6dd504078cd12639167641f3cc24 --- /dev/null +++ b/3rdparty/composable_kernel/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using AccDataType = float; + +struct YElementOp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + T a; + + ck::tensor_operation::element_wise::Sigmoid{}(a, x); + + y = x * a; + }; +}; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t H = 32; + ck::index_t W = 32; + ck::index_t G = 32; + ck::index_t C = 30; + + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + N = std::stoi(argv[1]); + H = std::stoi(argv[2]); + W = std::stoi(argv[3]); + G = std::stoi(argv[4]); + C = std::stoi(argv[5]); + } + else + { + std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl; + + return 1; + } + + Tensor x({N, H, W, G, C}); + Tensor y({N, H, W, G, C}); + Tensor gamma({G, C}); + Tensor beta({G, C}); + + ck::utils::FillUniformDistribution{0.f, 1.f}(x); + ck::utils::FillUniformDistribution{0.f, 1.f}(gamma); + ck::utils::FillUniformDistribution{0.f, 1.f}(beta); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + const auto y_element_op = YElementOp{}; + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {N, H, W, G, C}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, 0, 0, C, 1}, + {0, 0, 0, C, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + {1, 2, 4}, // reduction dimension: [H, W, C] + 1e-6, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + nullptr, + nullptr, + y_element_op); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true}); + + std::size_t num_btype = sizeof(XDataType) * N * H * W * G * C + + sizeof(YDataType) * N * H * W * G * C + sizeof(GammaDataType) * G * C + + sizeof(BetaDataType) * G * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s, " + << device_instance.GetTypeString() << std::endl; + + bool pass = true; + { + Tensor host_y({N, H, W, G, C}); + using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(x, gamma, beta, host_y, y_element_op, {N, H, W, G, C}, 1e-6); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); + } + + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt b/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c29f18f1627ef20fe69cb3751d3a7766ffbd236c --- /dev/null +++ b/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) diff --git a/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ac4b68272e2b0a1a5cbc09fa3734e190d6897da --- /dev/null +++ b/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|CBlockTransferClusterLengths| CBlockTransfer| + //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G2_M2_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int split_k = 1; + + ck::index_t G0 = 1; + ck::index_t G1 = 2; + + ck::index_t M0 = 4; + ck::index_t M1 = 256; + + ck::index_t N0 = 16; + ck::index_t N1 = 128; + + ck::index_t K0 = 64 * 2; + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + split_k = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a_gs_ms_ks( + std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + Tensor b_gs_ns_ks( + std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + Tensor d_gs_ms_ns( + std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + Tensor e_gs_ms_ns_host_result( + std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + Tensor e_gs_ms_ns_device_result( + std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{1}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op, + split_k); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(), + e_gs_ms_ns_lengths.begin() + NumDimG, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG, + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result( + std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + + using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + e_gs_ms_ns_host_result.ForEach([&](auto&, auto idx) { + cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx)); + }); + + return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) + ? 0 + : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp b/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..764e55ef558bd977fca8f25f16b2bba4c111114b --- /dev/null +++ b/3rdparty/composable_kernel/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|CBlockTransferClusterLengths| CBlockTransfer| + //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G2_M2_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int split_k = 1; + + ck::index_t G0 = 1; + ck::index_t G1 = 2; + + ck::index_t M0 = 4; + ck::index_t M1 = 256; + + ck::index_t N0 = 16; + ck::index_t N1 = 128; + + ck::index_t K0 = 64 * 2; + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + split_k = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a_gs_ms_ks( + std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + Tensor b_gs_ns_ks( + std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + Tensor d_gs_ms_ns( + std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + Tensor e_gs_ms_ns_host_result( + std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + Tensor e_gs_ms_ns_device_result( + std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{1}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op, + split_k); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(), + e_gs_ms_ns_lengths.begin() + NumDimG, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG, + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN, + ck::index_t{1}, + std::multiplies{}); + + ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK, + ck::index_t{1}, + std::multiplies{}); + + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result( + std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + + using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + e_gs_ms_ns_host_result.ForEach([&](auto&, auto idx) { + cde_element_op(e_gs_ms_ns_host_result(idx), c_ms_ns_host_result(idx), d_gs_ms_ns(idx)); + }); + + return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) + ? 0 + : 1; + } + + return 0; +} diff --git a/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/CMakeLists.txt b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f02e5110d071e3a43b057c5a7c4b927373c1646b --- /dev/null +++ b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) diff --git a/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..832665edc055db434ce14044c15fbded8b51b847 --- /dev/null +++ b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 8>; + +template +bool run_grouped_conv_fwd(bool do_verification, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& bias_g_k_desc, + const HostTensorDescriptor& requant_scale_g_k_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_k_desc); + Tensor requant_scale(requant_scale_g_k_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "requant_scale: " << requant_scale.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + wei.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + bias.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + requant_scale.GenerateTensorValue(GeneratorTensor_2{0, 1}); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem requant_scale_device_buf(sizeof(RequantScaleDataType) * + requant_scale.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + requant_scale_device_buf.ToDevice(requant_scale.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array d1_g_n_k_wos_lengths{}; + std::array d1_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_k_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(bias_g_k_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(requant_scale_g_k_desc.GetLengths(), d1_g_n_k_wos_lengths); + copy(requant_scale_g_k_desc.GetStrides(), d1_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer(), requant_scale_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}, + {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach([&](auto&, auto idx) { + out_element_op(out_host(idx), c_host(idx), bias(idx), requant_scale(idx)); + }); + + out_device_buf.FromDevice(out_device.mData.data()); + + pass &= + ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return (pass ? 0 : 1); +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + const ck::index_t ndim_spatial = 2; + + ck::utils::conv::ConvParam conv_param{ + ndim_spatial, // n_dim + 1, // group + 4, // batch + 64, // output channels + 32, // input chanels + {3, 3}, // weight HW + {71, 71}, // x HW + {2, 2}, // strides + {1, 1}, // dilations + {1, 1}, // left_pads + {1, 1} // right_pads + }; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{ActivationOp{}}; + + using InLayout = ck::tensor_layout::convolution::GNHWC; + using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using BiasLayout = ck::tensor_layout::convolution::G_K; + using RequantScaleLayout = ck::tensor_layout::convolution::G_K; + using OutLayout = ck::tensor_layout::convolution::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + // TODO - make_bias_host_tensor_descriptor_g_n_k_wos_packed() + const auto bias_g_k_desc = HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // ho + 0 // wo + }); + + const auto requant_scale_g_k_desc = bias_g_k_desc; + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::cout << out_g_n_k_wos_desc << std::endl; + + using deviceOp = DeviceGroupedConvNDFwdInstance; + + return run_grouped_conv_fwd(do_verification, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + bias_g_k_desc, + requant_scale_g_k_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); +} diff --git a/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5401350352ee0506f7fb99ceae1fcc52e22833a --- /dev/null +++ b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 8>; + +template +bool run_grouped_conv_fwd(bool do_verification, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& bias_g_k_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_k_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_k_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(bias_g_k_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {d0_g_n_k_wos_lengths}, + {d0_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach( + [&](auto&, auto idx) { out_element_op(out_host(idx), c_host(idx), bias(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + pass &= + ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return (pass ? 0 : 1); +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + const ck::index_t ndim_spatial = 2; + + ck::utils::conv::ConvParam conv_param{ + ndim_spatial, // n_dim + 1, // group + 4, // batch + 64, // output channels + 32, // input chanels + {3, 3}, // weight HW + {71, 71}, // x HW + {2, 2}, // strides + {1, 1}, // dilations + {1, 1}, // left_pads + {1, 1} // right_pads + }; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; + + using InLayout = ck::tensor_layout::convolution::GNHWC; + using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using BiasLayout = ck::tensor_layout::convolution::G_K; + using OutLayout = ck::tensor_layout::convolution::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + // TODO - make_bias_host_tensor_descriptor_g_n_k_wos_packed() + const auto bias_g_k_desc = HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // ho + 0 // wo + }); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::cout << out_g_n_k_wos_desc << std::endl; + + return run_grouped_conv_fwd< + ndim_spatial, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceGroupedConvNDFwdInstance>( + do_verification, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + bias_g_k_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); +} diff --git a/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d46d86655cdbaa6afc07a10dc50a5ce2a809bfe --- /dev/null +++ b/3rdparty/composable_kernel/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 16>; + +template +bool run_grouped_conv_fwd(bool do_verification, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + pass &= + ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return (pass ? 0 : 1); +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + const ck::index_t ndim_spatial = 2; + + ck::utils::conv::ConvParam conv_param{ + ndim_spatial, // n_dim + 1, // group + 4, // batch + 64, // output channels + 32, // input chanels + {3, 3}, // weight HW + {71, 71}, // x HW + {2, 2}, // strides + {1, 1}, // dilations + {1, 1}, // left_pads + {1, 1} // right_pads + }; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; + + using InLayout = ck::tensor_layout::convolution::GNHWC; + using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using OutLayout = ck::tensor_layout::convolution::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + return run_grouped_conv_fwd< + ndim_spatial, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceGroupedConvNDFwdInstance>( + do_verification, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); +} diff --git a/3rdparty/composable_kernel/example/44_elementwise_permute/CMakeLists.txt b/3rdparty/composable_kernel/example/44_elementwise_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0e0091a986ba6f92f49eb83dc8ae068636c7385e --- /dev/null +++ b/3rdparty/composable_kernel/example/44_elementwise_permute/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) +add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) diff --git a/3rdparty/composable_kernel/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/3rdparty/composable_kernel/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bbdbe52b9a7d813ec96cd18301e785a4d4e075a --- /dev/null +++ b/3rdparty/composable_kernel/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -0,0 +1,116 @@ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + PassThrough, + 4, + 8, + ck::Sequence<8>, + ck::Sequence<1>>; + +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) +{ + for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) + for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) + for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) + for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) + { + auto a_val = A_nchw(n, c, h, w); + functor(B_nhwc(n, h, w, c), a_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::vector nhwc = {16, 32, 64, 128}; + Tensor a(nchw); + Tensor b(nhwc); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths; + std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + std::array b_strides = {static_cast(nhwc[1] * nhwc[2] * nhwc[3]), + 1, + static_cast(nhwc[2] * nhwc[3]), + static_cast(nhwc[3])}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(nhwc); + host_elementwise4D(host_b, a, PassThrough{}); + + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/3rdparty/composable_kernel/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f16ad3b3c5c4c5b3aa7f53bb63c7a0149dd7584c --- /dev/null +++ b/3rdparty/composable_kernel/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp @@ -0,0 +1,130 @@ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_2d.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; + +using ADataType = F16; +using BDataType = F16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + PassThrough, + 3, // NumDim_M + 1, // NumDim_N + 8, + 8, + ck::Sequence<8>, + ck::Sequence<8>>; + +template +void host_elementwise4D(HostTensorB& B_nhwc, + const HostTensorA& A_nchw, + const std::vector& shape_nchw, + Functor functor) +{ + for(std::size_t n = 0; n < shape_nchw[0]; ++n) + for(std::size_t c = 0; c < shape_nchw[1]; ++c) + for(std::size_t h = 0; h < shape_nchw[2]; ++h) + for(std::size_t w = 0; w < shape_nchw[3]; ++w) + { + auto a_val = A_nchw(n, c, h, w); + functor(B_nhwc(n, h, w, c), a_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + const int N = 120; + const int C = 128; + const int H = 32; + const int W = 1024; + + /**const int N = 120; + const int H = 32; + const int W = 64; + + const int C = 128;**/ + + std::vector nchw = {N, C, H, W}; + std::vector nhwc = {N, H, W, C}; + + Tensor a(nchw); + Tensor b(nhwc); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + // LogRangeAsType(std::cout << "Tensor a : ", a.mData, ",") << std::endl; + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths{N, H, W, C}; + + std::array a_strides = {C * H * W, W, 1, H * W}; + std::array b_strides = {H * W * C, W * C, C, 1}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + // LogRangeAsType(std::cout << "Tensor b : ", b.mData, ",") << std::endl; + + Tensor host_b(nhwc); + host_elementwise4D, Tensor, PassThrough>( + host_b, a, nchw, PassThrough{}); + + // LogRangeAsType(std::cout << "Host b : ", host_b.mData, ",") << std::endl; + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/example/45_elementwise_normalization/CMakeLists.txt b/3rdparty/composable_kernel/example/45_elementwise_normalization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8f5b9d4d87841b913fabc29971b6ae3fc197bde1 --- /dev/null +++ b/3rdparty/composable_kernel/example/45_elementwise_normalization/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_elementwise_layernorm_blockwise elementwise_layernorm_blockwise.cpp) diff --git a/3rdparty/composable_kernel/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/3rdparty/composable_kernel/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12f6bc6ffeab8f08e0ec9851fe3e3654b05d8068 --- /dev/null +++ b/3rdparty/composable_kernel/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" + +using ADataType = ck::half_t; // Input 1 +using BDataType = ck::half_t; // Input 2 +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using AccDataType = float; +using XElementwiseOperation = ck::tensor_operation::element_wise::Add; +using YElementwiseOperation = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +// X = Elementwise(input1, input2, input3, ...) +// Y = Layernorm(X, beta, gamma) +using DeviceInstance = ck::tensor_operation::device::DeviceElementwiseNormalizationImpl< + ck::Tuple, + GammaDataType, + BetaDataType, + AccDataType, + YDataType, + XElementwiseOperation, + YElementwiseOperation, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterM + 32, // ClusterK + 1, // SliceM + 32, // SliceK + 1, // SrcVecDim (0=M, 1=K) + 8, // SrcScalarPerVector + 1, // GammaVecDim (0=M, 1=K) + 8, // GammaScalarPerVector + 1, // BetaVecDim (0=M, 1=K) + 8, // BetaScalarPerVector + 8>; // OutScalarPerVector + +template +void host_elementwise2D(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t m = 0; m < shape[0]; ++m) + for(std::size_t n = 0; n < shape[1]; ++n) + { + auto a_val = A(m, n); + auto b_val = B(m, n); + ctype c_val = 0; + functor(c_val, a_val, b_val); + C(m, n) = c_val; + } +} + +int main() +{ + bool time_kernel = true; + + ck::index_t M = 48 * 256; + ck::index_t N = 1024; + ck::index_t Stride = N; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + }; + + Tensor a(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor b(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor gamma(f_host_tensor_descriptor1d(N, 1)); + Tensor beta(f_host_tensor_descriptor1d(N, 1)); + Tensor y(f_host_tensor_descriptor2d(M, N, Stride)); + + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + beta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem a_dev(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_dev(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + + a_dev.ToDevice(a.mData.data()); + b_dev.ToDevice(b.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + std::array input = {a_dev.GetDeviceBuffer(), b_dev.GetDeviceBuffer()}; + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {M, N}, + { + std::vector{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}, + std::vector{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}, + }, + {0, 1}, + {0, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + {1}, + 1e-4, + input, + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + XElementwiseOperation{}, + YElementwiseOperation{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + float ela_time = 0; + ela_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + float data_mem_size = M * N * sizeof(ADataType) + M * N * sizeof(BDataType) + + M * N * sizeof(YDataType) + N * sizeof(GammaDataType) + + N * sizeof(BetaDataType); + float bandwidth = data_mem_size * 1000 / ela_time / 1024 / 1024 / 1024; + + std::cout << "Bandwidth is : " << bandwidth << "GB/s . " << std::endl; + std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl; + + bool pass = true; + { + std::vector mn = {static_cast(M), + static_cast(N)}; + Tensor x(f_host_tensor_descriptor2d(M, N, Stride)); + host_elementwise2D, + Tensor, + Tensor, + XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{}); + + Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= + ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + if(!(pass)) + { + std::cout << "layernorm wrong" << std::endl; + } + } + return (pass ? 0 : 1); +} diff --git a/3rdparty/composable_kernel/example/CMakeLists.txt b/3rdparty/composable_kernel/example/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..02a77f952c40a47e24b0d9bc9fbe1b34913e8292 --- /dev/null +++ b/3rdparty/composable_kernel/example/CMakeLists.txt @@ -0,0 +1,34 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include +) + +add_custom_target(examples) + +function(add_example_executable EXAMPLE_NAME FILE_NAME) + message("adding example ${EXAMPLE_NAME}") + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + # HC + target_compile_options(${EXAMPLE_NAME} PRIVATE --gpu-max-threads-per-block=1024) + add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + add_dependencies(examples ${EXAMPLE_NAME}) + add_dependencies(check ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) +endfunction(add_example_executable EXAMPLE_NAME) + +function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) + message("adding example ${EXAMPLE_NAME}") + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_dependencies(examples ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) +endfunction(add_example_executable_no_testing EXAMPLE_NAME) + +# add all example subdir +file(GLOB dir_list LIST_DIRECTORIES true *) +FOREACH(subdir ${dir_list}) + IF(IS_DIRECTORY "${subdir}") + add_subdirectory(${subdir}) + ENDIF() +ENDFOREACH() diff --git a/3rdparty/composable_kernel/include/ck/ck.hpp b/3rdparty/composable_kernel/include/ck/ck.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0e71df0046eb3413cf6567dd32ee3b60fbcf3f39 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/ck.hpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif + +#define CK_TIME_KERNEL 1 + +// constant address space for kernel parameter +// https://llvm.org/docs/AMDGPUUsage.html#address-spaces +#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) + +// launch bounds +#define CK_USE_LAUNCH_BOUNDS 1 + +#ifdef CK_USE_LAUNCH_BOUNDS +#define CK_MAX_THREAD_PER_BLOCK 256 +#define CK_MIN_BLOCK_PER_CU 2 +#endif + +// check GPU target +#ifdef __HIP_DEVICE_COMPILE__ +#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__)) +#error Not supported target +#endif +#endif + +// buffer resource +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_BUFFER_RESOURCE_3RD_DWORD -1 +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx908__) || \ + defined(__gfx90a__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(__gfx1030__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#elif defined(__gfx1100__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000 +#endif + +// FMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing +#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code +#define CK_USE_AMD_V_MAC_F32 +#elif defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx1030__) // for GPU code +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8 +#endif + +// MFMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_MFMA +#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_MFMA +#endif + +#if defined(__gfx90a__) +#define CK_USE_AMD_MFMA_BF16_1K_OP +#endif + +// WMMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_WMMA +#elif defined(__gfx1100__) // for GPU code +#define CK_USE_AMD_WMMA +#endif + +// buffer load +#define CK_USE_AMD_BUFFER_LOAD 1 + +// buffer store +#define CK_USE_AMD_BUFFER_STORE 1 + +// buffer atomic add: integer +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 + +// buffer atomic add: floating point +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#else // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 +#endif + +#if defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 +#else +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 +#endif + +// inline asm +#define CK_USE_AMD_INLINE_ASM 1 + +// inner product (DLOP) +#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 + +// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) +#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 + +// experimental feature: multi index implemented as array +#define CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 + +// experimental feature: static tensor descriptor +#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0 + +// experimental feature: buffer load/store/atomic-add/ OOB trick +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting. Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter for each usage +#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif +#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 + +// experimental feature: in-regsiter sub-dword transpose +#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1 + +// experimental feature: merge transformation use magic number division +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 + +// experimental feature: use __builtin_memcpy instead of pointer cast to access a vector from +// pointer of scalar +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0 + +// experimental feature: use __builtin_memcpy instead of union to do bit_cast +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 + +// experimental feature: optimize for inter-wave scheduling policy +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 1 +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1 +// this will let make_default_loop_scheduler() return interwave scheduling flag by default +#define CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING 0 +// experimental feature: add instances using interwave scheduling +#define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1 +// experimental feature: add instances using pipeline v2 +#define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1 + +// hack: have underlying assumption that need to be satsified, otherwise it's a bug +// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be +// thread-invariant, otherwise it's a bug +// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" +#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 + +// workaround: compiler crash when compiling recursive lambda +#define CK_WORKAROUND_SWDEV_275126 1 + +// workaround: compiler crash when using buffer load/store for i8 +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 + +// workaround: compiler gnerating inefficient ds_write instructions +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 + +// workaround: verifaction failure, due to compiler regression, for conv bwd-data fp16 using some +// tuning parameter +#define CK_WORKAROUND_SWDEV_325164 0 + +// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue +#ifdef __gfx908__ +#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1 +#else // __gfx90a__, ... +#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0 +#endif // __gfx908__ + +namespace ck { + +enum struct InMemoryDataOperationEnum +{ + Set, + AtomicAdd, + AtomicMax, + Add +}; + +// FIXME: use regular Sequence and remove this +template +struct InMemoryDataOperationEnumSequence +{ + static constexpr int mSize = sizeof...(Is); + + __host__ __device__ static constexpr InMemoryDataOperationEnum At(int I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const InMemoryDataOperationEnum mData[mSize + 1] = {Is..., InMemoryDataOperationEnum::Set}; + return mData[I]; + } +}; + +// index type +using index_t = int32_t; +using long_index_t = int64_t; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/host_utility/device_prop.hpp b/3rdparty/composable_kernel/include/ck/host_utility/device_prop.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e2cbdb733272d37aa5dbc7a746e86911d6b8644f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/host_utility/device_prop.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +namespace ck { + +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string raw_name(props.gcnArchName); + + // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + static std::map device_name_map = { + {"Ellesmere", "gfx803"}, + {"Baffin", "gfx803"}, + {"RacerX", "gfx803"}, + {"Polaris10", "gfx803"}, + {"Polaris11", "gfx803"}, + {"Tonga", "gfx803"}, + {"Fiji", "gfx803"}, + {"gfx800", "gfx803"}, + {"gfx802", "gfx803"}, + {"gfx804", "gfx803"}, + {"Vega10", "gfx900"}, + {"gfx901", "gfx900"}, + {"10.3.0 Sienna_Cichlid 18", "gfx1030"}, + }; + + const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. + + auto match = device_name_map.find(name); + if(match != device_name_map.end()) + return match->second; + return name; +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/host_utility/hip_check_error.hpp b/3rdparty/composable_kernel/include/ck/host_utility/hip_check_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d3dc8eaf1eb8b87207256ea4a521e23d5a49ca9c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/host_utility/hip_check_error.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +inline void hip_check_error(hipError_t x) +{ + if(x != hipSuccess) + { + std::ostringstream ss; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ + << "in function: " << __func__; + throw std::runtime_error(ss.str()); + } +} diff --git a/3rdparty/composable_kernel/include/ck/host_utility/io.hpp b/3rdparty/composable_kernel/include/ck/host_utility/io.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac8719592db8b48996bfbc65dc656d6d96bde545 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/host_utility/io.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const std::array& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const ck::TensorDescriptor& desc) +{ + constexpr ck::index_t nDim = ck::remove_cvref_t::GetNumOfDimension(); + + os << "{"; + + ck::static_for<0, nDim - 1, 1>{}([&](auto i) { os << desc.GetLength(i) << ", "; }); + + os << desc.GetLength(ck::Number{}); + + os << "}"; + + return os; +} diff --git a/3rdparty/composable_kernel/include/ck/host_utility/kernel_launch.hpp b/3rdparty/composable_kernel/include/ck/host_utility/kernel_launch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed6e2f0ba1d0ca433487fa32f374fde3b7ecee0e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/host_utility/kernel_launch.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +template +float launch_and_time_kernel(const StreamConfig& stream_config, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + const int nrepeat = 10; + + printf("Warm up 1 time\n"); + + // warm up + kernel<<>>(args...); + + printf("Start running %d times...\n", nrepeat); + + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + kernel<<>>(args...); + } + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + return total_time / nrepeat; + } + else + { + kernel<<>>(args...); + + return 0; + } +#else + kernel<<>>(args...); + + return 0; +#endif +} diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..db8e48df6d4b18e9c8a60b8f3d4382d1c6e649f8 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,275 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Number of GEMMs = YTilde * XTilde +// GemmM = C +// GemmN = N * HTildeSlice * WTildeSlice +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + constexpr auto IYTilde = Number{}; + constexpr auto IXTilde = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilde), + make_freeze_transform(IXTilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(IYTilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IXTilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_pass_through_transform(C), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5391b595b5c1e9b5b7395d1a676efcd87b9f2e6e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: wei +// C: in +// Number of GEMMs = YTilde * XTilde +// GemmM = N * HTildeSlice * WTildeSlice +// GemmN = C +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + IYTilde i_ytilde, + IXTilde i_xtilde, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // B: weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +// A: out +// B: wei +// C: in +// Number of GEMMs = 1 +// GemmM = N * Ho * Wo +// GemmN = C +// GemmK = K +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& /* wei_k_y_x_c_grid_desc */, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bb1dc239f4d7e6a5abade2dda4a17a98ab7d1a47 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ca530934e49dcccea388fd8501e5bf046f1c94af --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e960f90c4bb1974ddfaaf9faf6a12d75d1db2f4d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: output tensor + const auto out_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..052bab423db04740c88721026a3cceb3fa77c292 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: output tensor + const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c301a9e0c67b11c42fbe2a0a99c4f764049a5222 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: in +// C: wei +// GemmM = K +// GemmN = Y * X * C +// GemmKTotal = N * Ho * Wo +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41267536551ea298e111f6d93b8e6f3f8f9ed475 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Do * Ho * Wo +// GemmN = K +// GemmK = Z * Y * X * C +template +__host__ __device__ constexpr auto +transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + const TensorDescriptor& in_grid_desc_n_di_hi_wi_c, + const TensorDescriptor& wei_k_z_y_x_c_grid_desc, + const TensorDescriptor& out_n_do_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_grid_desc_n_di_hi_wi_c.GetLength(I0); + const auto K = out_n_do_ho_wo_k_grid_desc.GetLength(I4); + const auto C = in_grid_desc_n_di_hi_wi_c.GetLength(I4); + + const auto Di = in_grid_desc_n_di_hi_wi_c.GetLength(I1); + const auto Hi = in_grid_desc_n_di_hi_wi_c.GetLength(I2); + const auto Wi = in_grid_desc_n_di_hi_wi_c.GetLength(I3); + + const auto Do = out_n_do_ho_wo_k_grid_desc.GetLength(I1); + const auto Ho = out_n_do_ho_wo_k_grid_desc.GetLength(I2); + const auto Wo = out_n_do_ho_wo_k_grid_desc.GetLength(I3); + + const auto Z = wei_k_z_y_x_c_grid_desc.GetLength(I1); + const auto Y = wei_k_z_y_x_c_grid_desc.GetLength(I2); + const auto X = wei_k_z_y_x_c_grid_desc.GetLength(I3); + + const auto ConvStrideD = conv_strides[I0]; + const auto ConvStrideH = conv_strides[I1]; + const auto ConvStrideW = conv_strides[I2]; + + const auto ConvDilationD = conv_dilations[I0]; + const auto ConvDilationH = conv_dilations[I1]; + const auto ConvDilationW = conv_dilations[I2]; + + const auto InLeftPadD = in_left_pads[I0]; + const auto InLeftPadH = in_left_pads[I1]; + const auto InLeftPadW = in_left_pads[I2]; + + const auto InRightPadD = in_right_pads[I0]; + const auto InRightPadH = in_right_pads[I1]; + const auto InRightPadW = in_right_pads[I2]; + + const auto GemmM = N * Do * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Z * Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_di_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_dip_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{})); + + const auto in_grid_desc_gemmk_gemmm = + transform_tensor_descriptor(in_grid_desc_n_z_do_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_gemmk0_gemmm_gemmk1 = + transform_tensor_descriptor(in_grid_desc_gemmk_gemmm, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_grid_desc_gemmk_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Z * Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_grid_desc_gemmk0_gemmn_gemmk1 = + transform_tensor_descriptor(wei_grid_desc_gemmk_gemmn, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + // out_n_do_ho_wo_k_grid_desc, + // make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + // make_pass_through_transform(K)), + // make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<3>{}), + // make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_grid_desc_gemmk0_gemmm_gemmk1, + wei_grid_desc_gemmk0_gemmn_gemmk1, + out_grid_desc_gemmm_gemmn); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000000000000000000000000000000000..381f9ac9d6f5a9b000eb44839c7198e4dc9f5d9b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ebfaabb03ebdda41d086e4b80c5856156810992a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple( + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); +} + +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple( + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e576d69f5f4b5a9a60a05f84edb43fec990e91a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13e1bf251abfdc3716bdd17f331f82fd7bdfad41 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..088d14b2ee472bc049385ad27e45d02df1e3b6c9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = N * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6785d56df7e70fdb4882f0c431b886d8c87caac --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM0 = 1 +// GemmM1 = K +// GemmN0 = N0 +// GemmN1 = (N / N0) * Ho * Wo +// GemmK0 = (C / C0) * Y * X +// GemmK1 = C0 +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const N0Type& N0, + const C0Type& C0) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto N1 = N / N0; + const auto C1 = C / C0; + + // weight tensor + const auto wei_gk0_gm0_gm1_gk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_unmerge_transform(make_tuple(I1, K)), + make_unmerge_transform(make_tuple(C0, C1 * Y * X))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(C0, C1)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); + + const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor( + in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C1, Y, X)), + make_pass_through_transform(N0), + make_merge_transform(make_tuple(N1, Ho, Wo)), + make_pass_through_transform(C0)), + make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_n_k_howo_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)); + + const auto out_n0_n1_1_k_howo_grid_desc = + transform_tensor_descriptor(out_n_k_howo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(I1, K)), + make_pass_through_transform(Ho * Wo)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor( + out_n0_n1_1_k_howo_grid_desc, + make_tuple(make_pass_through_transform(I1), + make_pass_through_transform(K), + make_pass_through_transform(N0), + make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), + make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return make_tuple( + wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/stream_config.hpp b/3rdparty/composable_kernel/include/ck/stream_config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70ca34555a01436c79ba244ac03572bb4e9520b4 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/stream_config.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +struct StreamConfig +{ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; + int log_level_ = 0; +}; diff --git a/3rdparty/composable_kernel/include/ck/tensor/static_tensor.hpp b/3rdparty/composable_kernel/include/ck/tensor/static_tensor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fee679f91060aca623ab498362de78fd8118fee2 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor/static_tensor.hpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_STATIC_TENSOR_HPP +#define CK_STATIC_TENSOR_HPP + +namespace ck { + +// StaticTensor for Scalar +template ::type = false> +struct StaticTensor +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {} + + __host__ __device__ constexpr StaticTensor(T invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // read access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const T& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // write access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr T& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + StaticBuffer data_; + static constexpr T zero_scalar_value_ = T{0}; + const T invalid_element_scalar_value_; + T ignored_element_scalar_; +}; + +// StaticTensor for vector +template ::type = false> +struct StaticTensorTupleOfVectorBuffer +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + static constexpr index_t num_of_vector_ = + math::integer_divide_ceil(element_space_size_, ScalarPerVector); + + using V = vector_type; + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() + : invalid_element_scalar_value_{0} + { + } + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // Get S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const S& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // Set S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr S& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + // Get X + // Idx is for S, not X. Idx should be aligned with X + template ::value && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr X GetAsType(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_.template GetAsType(Number{}); + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + // TODO: is this right way to initialize a vector? + return X{0}; + } + else + { + // TODO: is this right way to initialize a vector? + return X{invalid_element_scalar_value_}; + } + } + } + + // Set X + // Idx is for S, not X. Idx should be aligned with X + template ::value && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr void SetAsType(Idx, X x) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + data_.template SetAsType(Number{}, x); + } + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr V& GetVectorTypeReference(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + StaticBufferTupleOfVector data_; + static constexpr S zero_scalar_value_ = S{0}; + const S invalid_element_scalar_value_ = S{0}; + S ignored_element_scalar_; +}; + +template ::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc) +{ + return StaticTensor{}; +} + +template < + AddressSpaceEnum AddressSpace, + typename T, + typename TensorDesc, + typename X, + typename enable_if::type = false, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value) +{ + return StaticTensor{invalid_element_value}; +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_description/cluster_descriptor.hpp b/3rdparty/composable_kernel/include/ck/tensor_description/cluster_descriptor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0c9ea2ff2a0d73b793008a954eaf7293b33ade08 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_description/cluster_descriptor.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template ::type> +__host__ __device__ constexpr auto make_cluster_descriptor( + const Lengths& lengths, + ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) +{ + constexpr index_t ndim_low = Lengths::Size(); + + const auto reordered_lengths = container_reorder_given_new2old(lengths, order); + + const auto low_lengths = generate_tuple( + [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number{}); + + const auto transform = make_merge_transform(low_lengths); + + constexpr auto low_dim_old_top_ids = ArrangeOrder{}; + + constexpr auto up_dim_new_top_ids = Sequence<0>{}; + + return make_single_stage_tensor_adaptor( + make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_description/multi_index_transform.hpp b/3rdparty/composable_kernel/include/ck/tensor_description/multi_index_transform.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4e4d7593e9083f321f1cabe65faaee3c14259666 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_description/multi_index_transform.hpp @@ -0,0 +1,1954 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/multi_index.hpp" + +namespace ck { + +template +struct PassThrough +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr PassThrough() = default; + + __host__ __device__ constexpr PassThrough(const LowLength& low_length) + : up_lengths_{make_tuple(low_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("PassThrough, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct Pad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + RightPadLength right_pad_length_; + + __host__ __device__ constexpr Pad() = default; + + __host__ __device__ constexpr Pad(const LowLength& low_length, + const LeftPadLength& left_pad_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)}, + left_pad_length_{left_pad_length}, + right_pad_length_{right_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || + ((idx_up[Number<0>{}] >= left_pad_length_) && + (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_)); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Pad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_length %d", index_t{left_pad_length_}); + printf("right_pad_length %d", index_t{right_pad_length_}); + printf("}"); + } +}; + +template +struct LeftPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + + __host__ __device__ constexpr LeftPad() = default; + + __host__ __device__ constexpr LeftPad(const LowLength& low_length, + const LeftPadLength& left_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("LeftPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_length_ %d", index_t{left_pad_length_}); + printf("}"); + } +}; + +template +struct RightPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LowLength low_length_; + RightPadLength right_pad_length_; + + __host__ __device__ constexpr RightPad() = default; + + __host__ __device__ constexpr RightPad(const LowLength& low_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + right_pad_length)}, + low_length_{low_length}, + right_pad_length_{right_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("RightPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("low_length_ %d", index_t{low_length_}); + printf("left_pad_length_ %d", index_t{right_pad_length_}); + printf("}"); + } +}; + +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] +// UpLengths and Coefficients can be either of the followings: +// 1) Tuple of index_t, which is known at run-time, or +// 2) Tuple of Number, which is known at compile-time, or +// 3) Tuple of mixture of index_t and Number, which is known partially at run-time and partially +// at compile-time +template ::type = false> +struct Embed +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + UpLengths up_lengths_; + Coefficients coefficients_; + + __host__ __device__ constexpr Embed() = default; + + __host__ __device__ constexpr Embed(const UpLengths& up_lengths, + const Coefficients& coefficients) + : up_lengths_{up_lengths}, coefficients_{coefficients} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) { + idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i]; + }); + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp && + LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_diff_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}( + [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; }); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Embed, "); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("coefficients_ "); + print_multi_index(coefficients_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses regular to do lowering of +// multi-index and use carry-and-borrow check to do lowering of multi-index delta +template +struct Merge_v1_carry_check +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v1_carry_check() = default; + + __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // normal division + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]); + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] - borrow; + + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t negative_idx_low_tmp = borrow - idx_low[i]; + + bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at run-time each time this function is called, and can be + // very expensive. + LowerIndex idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + bool do_carry = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] + do_carry; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_carry = idx_low_tmp >= low_lengths_[i]; + +#if 0 + // TODO: use exec-mask inline asm, which use 1 VALU + if(do_carry) + { + idx_diff_low(i) -= low_lengths_[i]; + } +#elif 1 + // this use 2 VALU + idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i]; +#elif 1 + // this use 2 VALU + index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i]; + idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry; + + idx_low(I0) += idx_diff_low[I0]; + } + else if constexpr(Hack == 2) + { + // do borrow check on each low dimension in reversed order + // do not need to check the first dimension + bool do_borrow = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] - do_borrow; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_borrow = idx_low_tmp < 0; + +#if 0 + // TODO: use exec-mask inline asm + if(do_borrow) + { + idx_diff_low(i) += low_lengths_[i]; + } +#elif 1 + idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i]; +#elif 1 + index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i]; + idx_diff_low(i) = do_borrow ? idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow; + + idx_low(I0) += idx_diff_low[I0]; + } + else + { + // not implemented + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { +#if 1 + UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#elif 0 + UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#else + UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#endif + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v1_carry_check, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_shift +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicShift(LowLengths{}[i]); + } +}; + +// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For Merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. +template +struct Merge_v2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + using LowLengthsMagicDivisorMultipiler = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsMagicDivisorShift = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_; + LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v2_magic_division() = default; + + __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); }, + Number{})}, + low_lengths_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + }); + + idx_low(Number<0>{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + + index_t idx_low_old = idx_low[i]; + + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + + idx_diff_low(i) = idx_low[i] - idx_low_old; + }); + + idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{}); + + idx_low(Number<0>{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_magic_divisor_multiplier_); + printf("low_lengths_magic_divisor_shift_ "); + print_multi_index(low_lengths_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For Merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. +template +struct Merge_v2r2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsScanMagicDivisorShift = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_; + LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v2r2_magic_division() = default; + + __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + low_lengths_scan_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); }, + Number{})}, + low_lengths_scan_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + index_t idx_low_old = idx_low[i]; + + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + idx_diff_low(i) = idx_low[i] - idx_low_old; + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_diff_low(Number{}) = tmp - idx_low[Number{}]; + + idx_low(Number{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v2r2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan "); + print_multi_index(low_lengths_scan_); + printf("low_lengths_scan_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_scan_magic_divisor_multiplier_); + printf("low_lengths_scan_magic_divisor_shift_ "); + print_multi_index(low_lengths_scan_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v3_division_mod +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v3_division_mod() = default; + + __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + const index_t tmp2 = idx_low[i]; + idx_low(i) = tmp / this->low_lengths_scan_[i]; + idx_diff_low(i) = idx_low[i] - tmp2; + tmp %= this->low_lengths_scan_[i]; + }); + + const index_t tmp2 = idx_low[INm1]; + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_low[INm1] - tmp2; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct UnMerge +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + using UpLengthsScan = + decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{})); + + UpLengths up_lengths_; + UpLengthsScan up_lengths_scan_; + + __host__ __device__ constexpr UnMerge() = default; + + __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths) + : up_lengths_{up_lengths}, + up_lengths_scan_{ + container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + if constexpr(!Use24BitIntegerCalculation) + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}( + [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; }); + } + else + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}([&](auto i) { + idx_low(Number<0>{}) = + (0x00ffffff & idx_low[Number<0>{}]) + + (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]); + }); + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + CalculateLowerIndex(idx_diff_low, idx_diff_up); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("UnMerge, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("up_lengths_scan_"); + print_multi_index(up_lengths_scan_); + printf("}"); + } +}; + +template +struct Freeze +{ + LowerIndex low_idx_; + + __host__ __device__ constexpr Freeze() = default; + + __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& /* idx_up */) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = low_idx_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& /* idx_low */, + const UpIdx& /* idx_up_new */, + Number) + { + idx_diff_low(Number<0>{}) = 0; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Freeze"); + printf("low_idx_ %d", index_t{low_idx_}); + } +}; + +// Insert a dangling upper dimension without lower dimension +template +struct Insert +{ + using UpLengths = decltype(make_tuple(UpperLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr Insert() = default; + + __host__ __device__ constexpr Insert(const UpperLength& up_length) + : up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + } + + template + __host__ __device__ static void + UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number) + { + static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Insert"); + print_multi_index(up_lengths_); + } +}; + +template +struct Vectorize +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(UpLength{})); + + UpLengths up_lengths_; + VectorSize vector_size_; + + __host__ __device__ constexpr Vectorize() = default; + + __host__ __device__ constexpr Vectorize(const VectorSize& vector_size, + const UpLength& up_length) + : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}]; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = vector_size_ * idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Vectorize, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct Slice +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); + + UpLengths up_lengths_; + SliceBegin slice_begin_; + SliceEnd slice_end_; + + __host__ __device__ constexpr Slice() = default; + + __host__ __device__ constexpr Slice(const LowLength&, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) + : up_lengths_{make_tuple(slice_end - slice_begin)}, + slice_begin_{slice_begin}, + slice_end_{slice_end} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Slice, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("slice_begin_ %d", index_t{slice_begin_}); + printf("slice_end %d", index_t{slice_end_}); + printf("}"); + } +}; + +/* + * \brief lower_idx = upper_idx % modulus. + * TODO: Need an improved implementation since the modulo operation is expensive. + */ +template +struct Modulo +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + using UpLengths = decltype(make_tuple(UpLength{})); + + Modulus modulus_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Modulo() = default; + + __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length) + : modulus_{modulus}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& up_idx, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + const auto idx_low_old = idx_low; + idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_; + idx_diff_low(I0) = idx_low - idx_low_old; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Modulus, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_description/multi_index_transform_helper.hpp b/3rdparty/composable_kernel/include/ck/tensor_description/multi_index_transform_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..044a90370095eb53b94b5c2fba81abdbeae82c00 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length) +{ + return PassThrough{low_length}; +} + +template +__host__ __device__ constexpr auto +make_pad_transform(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad, + integral_constant = integral_constant{}) +{ + return Pad{low_length, left_pad, right_pad}; +} + +template +__host__ __device__ constexpr auto make_left_pad_transform( + const LowLength& low_length, + const LeftPadLength& left_pad, + integral_constant = integral_constant{}) +{ + return LeftPad{low_length, left_pad}; +} + +template +__host__ __device__ constexpr auto make_right_pad_transform( + const LowLength& low_length, + const RightPadLength& right_pad, + integral_constant = integral_constant{}) +{ + return RightPad{low_length, right_pad}; +} + +template ::type = false> +__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, + const Coefficients& coefficients) +{ + return Embed{up_lengths, coefficients}; +} + +template +__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) +{ +#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION + return make_merge_transform_v2_magic_division(low_lengths); +#else + return make_merge_transform_v1_carry_check(low_lengths); +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v1_carry_check(const LowLengths& low_lengths) +{ + return Merge_v1_carry_check{low_lengths}; +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v2_magic_division(const LowLengths& low_lengths) +{ +#if 1 + return Merge_v2_magic_division{low_lengths}; +#else + return Merge_v2r2_magic_division{low_lengths}; +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v3_division_mod(const LowLengths& low_lengths) +{ + return Merge_v3_division_mod{low_lengths}; +} + +template +__host__ __device__ constexpr auto make_unmerge_transform( + const UpLengths& up_lengths, + integral_constant = integral_constant{}) +{ + return UnMerge{up_lengths}; +} + +template +__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) +{ + return Freeze{low_idx}; +} + +template +__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx) +{ + return Insert{up_idx}; +} + +template +__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) +{ + return Slice{low_length, slice_begin, slice_end}; +} + +template +__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, + const UpLength& up_length) +{ + return Vectorize{vector_size, up_length}; +} + +template +__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, + const UpLength& up_length) +{ + return Modulo{modulus, up_length}; +} +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_description/tensor_adaptor.hpp b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_adaptor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d42e0a6ff08f60eb1cf6e0f241f93e528bf3514b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_adaptor.hpp @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +namespace ck { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// BottomDimensionHiddenIds : Sequence<...> +// TopDimensionHiddenIds : Sequence<...> +template +struct TensorAdaptor +{ + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() + { + return LowerDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss() + { + return UpperDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetTopDimensionHiddenIds() + { + return TopDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds() + { + return BottomDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_top) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_top = Number{}; + + constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + __host__ __device__ static constexpr index_t GetNumOfBottomDimension() + { + return BottomDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfTopDimension() + { + return TopDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension(); + constexpr static index_t ndim_top_ = GetNumOfTopDimension(); + + using HiddenIndex = MultiIndex; + using BottomIndex = MultiIndex; + using TopIndex = MultiIndex; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: +#if 0 // workaround compiler complaint about constexpr + __host__ __device__ constexpr TensorAdaptor() = default; +#else + __host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {} +#endif + + __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) + : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionHiddenIdss::Size() == ntransform_ && + UpperDimensionHiddenIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + +#if 0 // debug + template + __host__ __device__ constexpr index_t GetTopDimensionLength(Number idim) const + { + // TODO: not implemented + } + + template + __host__ __device__ constexpr index_t GetBottomDimensionLength(Number idim) const + { + // TODO: not implemented + } +#endif + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = GetNumOfTransform(); + constexpr index_t ndim_hidden = GetNumOfHiddenDimension(); + + MultiIndex idx_hidden; + + // initialize uppest index + set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top); + + // calculate hidden index + static_for{}([&](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = GetTransforms().At(itran); + constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran); + constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= remove_cvref_t::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorAdaptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionHiddenIds:"); + LowerDimensionHiddenIdss{}.At(i).Print(); + printf("UpperDimensionHiddenIds:"); + UpperDimensionHiddenIdss{}.At(i).Print(); + }); + + printf("BottomDimensionHiddenIds:"); + BottomDimensionHiddenIds::Print(); + printf("TopDimensionHiddenIds:"); + TopDimensionHiddenIds::Print(); + + printf("}"); + } + + private: + Transforms transforms_; + ElementSize element_size_; +}; + +template +__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, + const TensorAdaptor1& adaptor1) +{ + static_assert(TensorAdaptor0::GetNumOfTopDimension() == + TensorAdaptor1::GetNumOfBottomDimension(), + "wrong!"); + + // all_transforms = transform0 + transform1 + const auto all_transforms = + container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms()); + + // shift + constexpr index_t adaptor0_max_hidden_id = [&]() { + index_t adaptor0_max_hidden_id_ = NumericLimits::Min(); + + static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value); + }); + + constexpr index_t ndim_up = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor0_max_hidden_id_; + }(); + + constexpr index_t adaptor1_min_hidden_id = [&]() { + index_t adaptor1_min_hidden_id_ = NumericLimits::Max(); + + static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + // get the min of all lower dimenions, but not bottom dimension (because their id will + // be matched with top id from adaptor0) + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t low_dim_hidden_id = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value; + + bool is_bottom_dim = false; + static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) { + if constexpr(low_dim_hidden_id == + TensorAdaptor1::GetBottomDimensionHiddenIds()[i]) + { + is_bottom_dim = true; + } + }); + + if(!is_bottom_dim) + { + adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id); + } + }); + + constexpr index_t ndim_up = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + // get the min of all upper dimensions + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor1_min_hidden_id_ = + math::min(adaptor1_min_hidden_id_, + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor1_min_hidden_id_; + }(); + + constexpr index_t adaptor1_hidden_id_shift = + adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; + + constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension(); + + // all_low_dim_hidden_idss = + // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1)) + constexpr auto low_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size(); + + constexpr auto low_dim_hidden_ids_1 = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran]; + + // sequence in, sequence out + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr + { + auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); + + // shift hidden id so every dim id is unique + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift; + }); + + // match hidden id + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) + { + low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; + } + }); + }); + + return low_dim_hidden_ids_1_mod_; + } + (); + + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_low_dim_hidden_idss = + container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1); + + // all_up_dim_hidden_idss = + // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1) + constexpr auto up_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size(); + + constexpr auto up_dim_hidden_ids_1 = + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran]; + + // sequence in, constexpr tuple out + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr + { + auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); + + // shift hidden id + static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { + up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift; + }); + + return up_dim_hidden_ids_1_mod_; + } + (); + + // constexpr tuple to sequence + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_up_dim_hidden_idss = + container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1); + + // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 + constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds(); + + // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) + constexpr auto top_dim_hidden_ids = + TensorAdaptor1::GetTopDimensionHiddenIds() + Number{}; + + // put everything together + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms}; +} + +// Transforms: Tuple +// LowerDimensionOldTopIdss: Tuple, ...> +// UpperDimensionNewTopIdss: Tuple, ...> +template +__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms, + LowerDimensionOldTopIdss, + UpperDimensionNewTopIdss) +{ + constexpr index_t ntransform = Transforms::Size(); + + static_assert(LowerDimensionOldTopIdss::Size() == ntransform && + UpperDimensionNewTopIdss::Size() == ntransform, + "wrong!"); + + // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss + constexpr auto all_low_dim_old_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{}); + + constexpr auto all_up_dim_new_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + + constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size(); + constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size(); + + // low_dim_hidden_idss + constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{}; + + // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom + constexpr auto up_dim_hidden_idss = generate_tuple( + [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number{}; }, + Number{}); + + // bottom_dim_hidden_ids + constexpr auto bottom_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{}; + + // top_dim_hidden_ids + constexpr auto top_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number{}; + + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) +{ + return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_description/tensor_descriptor.hpp b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_descriptor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f07d5b1733d9cc96dfac9f18cbbe30510760cabc --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_descriptor.hpp @@ -0,0 +1,615 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" + +namespace ck { + +template +struct TensorCoordinate; + +template +struct TensorCoordinateStep; + +// Transforms: Tuple +// LowerDimensionIdss : Tuple, ...> +// UpperDimensionIdss : Tuple, ...> +// VisibleDimensionIds> : Sequence<...> +template +struct TensorDescriptor +{ + // TODO make these private + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ static constexpr index_t GetNumOfVisibleDimension() + { + return VisibleDimensionIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_visible) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_visible = Number{}; + + constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + + using VisibleIndex = MultiIndex; + using HiddenIndex = MultiIndex; + using Coordinate = TensorCoordinate; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: +#if 0 // workaround compiler complaint about constexpr + __host__ __device__ constexpr TensorDescriptor() = default; +#else + __host__ __device__ constexpr TensorDescriptor() + : transforms_{}, element_size_{}, element_space_size_{} + { + } +#endif + + __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) + : transforms_{transforms}, + element_size_{InitializeElementSize(transforms)}, + element_space_size_{element_space_size} + + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionIdss::Size() == ntransform_ && + UpperDimensionIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ static constexpr index_t GetNumOfDimension() + { + return GetNumOfVisibleDimension(); + } + + template + __host__ __device__ constexpr auto GetLength(Number) const + { + static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range"); + + constexpr auto tmp = GetTransformAndItsUpperDimension(Number{}); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + return transforms_[Number{}].GetUpperLengths()[Number{}]; + } + + __host__ __device__ constexpr auto GetLengths() const + { + // FIXME: use Tuple of reference instead + return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number{}); + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } + + template + __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const + { + static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension"); + + return make_tensor_coordinate(*this, idx).GetOffset(); + } + + // TODO make these private + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionIdss() + { + return LowerDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionIdss() + { + return UpperDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetVisibleDimensionIds() + { + return VisibleDimensionIds{}; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= remove_cvref_t::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorDescriptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionIds:"); + LowerDimensionIdss{}.At(i).Print(); + printf("UpperDimensionIds:"); + UpperDimensionIdss{}.At(i).Print(); + }); + printf("}"); + + VisibleDimensionIds::Print(); + } + + // TODO make these private + Transforms transforms_; + ElementSize element_size_; + ElementSpaceSize element_space_size_; +}; + +template +struct TensorCoordinate +{ + // TODO make these private + static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); + + using HiddenIndex = MultiIndex; + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr TensorCoordinate() = default; + + __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) + : idx_hidden_{idx_hidden} + { + } + + __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); } + + __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } + + // TODO make these private + __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + + __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; } + + __host__ __device__ constexpr auto GetVisibleIndex() const + { + return get_container_subset(idx_hidden_, VisibleDimensionIds{}); + } + + // TODO make these private + HiddenIndex idx_hidden_; +}; + +template +struct TensorCoordinateStep +{ + // TODO make these private + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr TensorCoordinateStep() = default; + + __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible, + const MultiIndex& do_transforms) + : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} + { + } + + __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } + + // TODO make these private + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + { + return idx_diff_visible_; + } + + VisibleIndex idx_diff_visible_; + MultiIndex do_transforms_; + + // HACK: control UpdateLowerIndex() + static constexpr UpdateLowerIndexHack update_lower_index_hack_; +}; + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor, and to put it outside the scope where it is used +// (transform_tensor_descriptor) because template cannot be defined inside a function +// template +template +struct lambda_get_up_dim_num +{ + template + __host__ __device__ constexpr auto operator()(I) const + { + using Tran = remove_reference_t; + return Number{}; + } +}; + +template +__host__ __device__ constexpr auto +transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) +{ + // sanity check + { + static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() && + NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(), + "wrong! inconsitent number of transform"); + + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + } + + // lower dimension's hidden idss + // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of + // sequences) + constexpr auto low_dim_hidden_idss = transform_tuples( + // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) + [](auto low_dim_visible_ids) constexpr { + return transform_sequences( + // convert lower dimension visible id to hidden id + [](auto low_dim_visible_id) constexpr { + return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; + }, + low_dim_visible_ids); + }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr index_t num_new_transform = NewTransforms::Size(); + + // upper dimension's hidden idss + constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension(); + + constexpr auto up_dim_numbers = + generate_sequence(lambda_get_up_dim_num{}, Number{}); + + constexpr auto up_dim_numbers_scan = merge_sequences( + Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); + + constexpr auto up_dim_hidden_idss = generate_tuple( + [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + return + typename arithmetic_sequence_gen::type{}; + }, + Number{}); + + // new visible dimension's hidden ids + constexpr auto unordered_new_visible_dim_hidden_ids = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + + constexpr auto new_visible_dim_unordered2ordered = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + constexpr auto new_visible_dim_hidden_ids = + unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); + + // put everything together + const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms); + + constexpr auto all_low_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); + + constexpr auto all_up_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); + + const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc, + const VisibleIndex& idx_visible) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + MultiIndex idx_hidden; + + // initialize visible index + set_container_subset(idx_hidden, visible_dim_ids, idx_visible); + + // calculate hidden index + static_for{}([&tensor_desc, &idx_hidden](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return TensorCoordinate{idx_hidden}; +} + +// UpdateLowerIndexHack: Sequence<...> +// HACK: control UpdateLowerIndex +template +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible, + UpdateLowerIndexHack) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!"); + + // use index_t for boolean type + auto do_transforms = make_zero_multi_index(); + auto is_non_zero_diff = make_zero_multi_index(); + + // decide do_transform by checkout non-zero index diff components + MultiIndex non_zero_diff_pick_visible; + + static_for<0, ndim_visible, 1>{}( + [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); }); + + set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible); + + static_for{}([&](auto itran) { + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); + + MultiIndex non_zero_diff_pick_low; + + // if any of upper index diff components is non-zero, then + // 1) Need to do this transform + // 2) all components of lower index diff will assume to be non-zero and need to be + // computed + const bool idx_diff_up_has_non_zero = container_reduce( + non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false); + + do_transforms(itran) = idx_diff_up_has_non_zero; + + static_for<0, dims_low.Size(), 1>{}( + [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; }); + + set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); + }); + + return TensorCoordinateStep{idx_diff_visible, + do_transforms}; +} + +template +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible) +{ + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + return make_tensor_coordinate_step( + TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc, + TensorCoord& coord, + const TensorCoordStep& coord_step) +{ + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + // this is what needs to be calculated + auto idx_diff_hidden = make_zero_multi_index(); + + // initialize visible index diff + set_container_subset( + idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff()); + + // this is what needs to be updated + auto& idx_hidden = coord.GetHiddenIndex(); + + // update visible index + auto idx_hidden_pick_visible = + get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); + + idx_hidden_pick_visible += coord_step.GetIndexDiff(); + + set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); + + // update rest of hidden index + static_for{}([&](auto itran) { + if(coord_step.do_transforms_[itran]) + { + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up_new = get_container_subset(idx_hidden, dims_up); + auto idx_low = get_container_subset(idx_hidden, dims_low); + const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); + + MultiIndex idx_diff_low; + + // HACK: control UpdateLowerIndex for Merge using hack + constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran); + + tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); + + set_container_subset(idx_diff_hidden, dims_low, idx_diff_low); + set_container_subset(idx_hidden, dims_low, idx_low); + } + }); +} + +template +__host__ __device__ constexpr bool +coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + bool valid = true; + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + const auto& idx_hidden = coord.GetHiddenIndex(); + + static_for{}([&tensor_desc, &idx_hidden, &valid](auto itran) { + const auto tran = tensor_desc.GetTransforms().At(itran); + + // check validity, only if current transformation does not always has a valid mapping + if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex()) + { + const auto idx_up = + get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran)); + + // Comment: using valid = valid && .. will result in weird control flow in ISA + valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up); + } + }); + + return valid; +} + +template +__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + // check visible index + const auto& idx_visible = coord.GetVisibleIndex(); + + bool is_visible_index_valid = true; + + static_for<0, TensorDesc::GetNumOfDimension(), 1>{}( + [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) { + is_visible_index_valid = + is_visible_index_valid && + (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i)); + }); + + // check other hidden index + return is_visible_index_valid && + coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord); +} + +template +using TensorCoordinate_t = decltype(make_tensor_coordinate( + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); + +template +using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_description/tensor_descriptor_helper.hpp b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_descriptor_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..461aae72cf7b1879c90e473adb3814e1c4875b52 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +namespace ck { + +/* + * These functions create tensor descriptor at runtime. If they are not constexpr, you will + * likely see usage of scratch memory during construction of these tensor descriptors. So + * it's better to call these functions on host and then pass the constructed tensor descritpors + * to GPU. If the tensor descritpors being constructed are constexpr, then you can call these + * functions on GPU without worrying about scratch memory usage. + */ + +#if CK_WORKAROUND_SWDEV_275126 +template +__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + Number i, + AccOld acc_old) +{ + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < Lengths::Size() - 1) + { + return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } +} +#endif + +// Lengths..., Strides... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> +template ::type = false> +__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple& lengths, + const Tuple& strides) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_embed_transform(lengths, strides)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + +#if !CK_WORKAROUND_SWDEV_275126 + // rocm-4.1 compiler would crash for recursive labmda + // recursive function for reduction + auto f = [&](auto fs, auto i, auto acc_old) { + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < N - 1) + { + return fs(fs, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } + }; + + const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{}); +#else + const auto element_space_size = + calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{}); +#endif + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> +template +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_packed(const Tuple& lengths) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_unmerge_transform(lengths)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{}); + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// align could be: +// 1) index_t, or +// 2) Number<> +template +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align align) +{ + constexpr auto I1 = Number<1>{}; + + constexpr index_t N = sizeof...(Lengths); + + const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number{}], align); + + auto strides = generate_tuple( + [&](auto i) { + if constexpr(i.value == N - 1) + { + return I1; + } + else if constexpr(i.value == N - 2) + { + return Number{}; + } + else + { + return container_reduce(lengths, + math::multiplies{}, + Number{}, + i + I1, + Number{}, + I1); + } + }, + Number{}); + + return make_naive_tensor_descriptor(lengths, strides); +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_description/tensor_space_filling_curve.hpp b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_space_filling_curve.hpp new file mode 100644 index 0000000000000000000000000000000000000000..17c9100b9fd76418d562c5c175ad24062d8c3415 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/statically_indexed_array_multi_index.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template // # of scalars per access in each dimension +struct SpaceFillingCurve +{ + static constexpr index_t nDim = TensorLengths::Size(); + + using Index = MultiIndex; + + static constexpr index_t ScalarPerVector = + reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{}); + + static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{}; + static constexpr auto dim_access_order = DimAccessOrder{}; + static constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(ordered_access_lengths)), + make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr index_t GetNumOfAccess() + { + static_assert(TensorLengths::Size() == ScalarsPerAccess::Size()); + static_assert(TensorLengths{} % ScalarsPerAccess{} == + typename uniform_sequence_gen::type{}); + + return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / + ScalarPerVector; + } + + template + static __device__ __host__ constexpr auto GetStepBetween(Number, + Number) + { + static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0"); + static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0"); + + constexpr auto idx_begin = GetIndex(Number{}); + constexpr auto idx_end = GetIndex(Number{}); + return idx_end - idx_begin; + } + + template + static __device__ __host__ constexpr auto GetForwardStep(Number) + { + static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be larger than 0"); + return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr auto GetBackwardStep(Number) + { + static_assert(AccessIdx1d > 0, "1D index should be larger than 0"); + + return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr Index GetIndex(Number) + { +#if 0 + /* + * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected. + */ + constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number{})); +#else + + constexpr auto access_strides = container_reverse_exclusive_scan( + ordered_access_lengths, math::multiplies{}, Number<1>{}); + + constexpr auto idx_1d = Number{}; + // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the + // idim-th element of multidimensional index. + // All constexpr variables have to be captured by VALUE. + constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr + { + constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr + { + auto res = idx_1d.value; + auto id = 0; + + static_for<0, jdim.value + 1, 1>{}([&](auto kdim) { + id = res / access_strides[kdim].value; + res -= id * access_strides[kdim].value; + }); + + return id; + }; + + constexpr auto id = compute_index_impl(idim); + return Number{}; + }; + + constexpr auto ordered_access_idx = generate_tuple(compute_index, Number{}); +#endif + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto idim) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, idim, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(idim) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate multi-dim tensor index + auto idx_md = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto idim) { + ordered_idx(idim) = + !SnakeCurved || forward_sweep[idim] + ? ordered_access_idx[idim] + : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + ScalarsPerAccess{}; + }(); + return idx_md; + } + + // FIXME: rename this function + template + static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number) + { + constexpr auto idx = GetIndex(Number{}); + + return generate_tuple([&](auto i) { return Number{}; }, Number{}); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b1b7be11ef7b3cd19f11bd400b8a96146921bf7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp" + +namespace ck { + +// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 +// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) +template + typename BM10BN10ThreadClusterBN10Xs, // Sequence + index_t AThreadCopyScalarPerVector_BM11, + index_t BThreadCopyScalarPerVector_BN11, + typename enable_if::type = false> +struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); + static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); + static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); + static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); + + static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; + static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; + + static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; + static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; + + static constexpr index_t BM11 = BM1PerThreadBM11; + static constexpr index_t BN11 = BN1PerThreadBN11; + + static constexpr index_t BM1 = BM100 * BM101 * BM11; + static constexpr index_t BN1 = BN100 * BN101 * BN11; + + static constexpr index_t BM0 = BM / BM1; + static constexpr index_t BN0 = BN / BN1; + + __host__ __device__ static constexpr auto + MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) + { + const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( + a_block_desc_bk0_bm_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_block_bk0_bm0_bm1_bk1; + } + + __host__ __device__ static constexpr auto + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) + { + const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( + b_block_desc_bk0_bn_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_block_desc_bk0_bn0_bn1_bk1; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM, BN] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM0, BM1, BN0, BN1] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; + } + + __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() + { + return Sequence{}; + } + + static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = + MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); + + static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); + + public: + __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)} + { + static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && + BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == BM101 * BM100 * BN101 * BN100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); + + static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == + BBlockDesc_BK0_BN_BK1{}.GetLength(I0), + "wrong! K dimension not consistent"); + + // TODO remove this restriction + static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && + BM10BN10ThreadClusterBN10Xs::Size() == 2, + "wrong!"); + + // TODO: remove this restriction + static_assert(BM0 == 2, "wrong"); + static_assert(BM0 == 2 && BN0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) + { + // lower: [BM0, BM1, BN0, BN1] + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + constexpr auto adaptor0 = + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); + + // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // upper: [Tid, BM0, BM11, BN0, BN11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), + make_pass_through_transform(BM0), + make_pass_through_transform(BM11), + make_pass_through_transform(BN0), + make_pass_through_transform(BN11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + template + __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(BM0 == 2 && BN0 == 2 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); + + constexpr auto threadwise_contraction = + ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + FloatA, + FloatB, + FloatC, + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + CThreadDesc_BM0_BM11_BN0_BN11, + Sequence, + Sequence<1, BM1PerThreadBM11>, + Sequence<1, BN1PerThreadBN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of bk0 + static_for{}([&](auto bk0) { + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[BK0, BM0, BM1, BK1] + static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + // B[BK0, BN0, BN1, BK1] + static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatA, + FloatA, + decltype(a_block_desc_bk0_bm0_bm1_bk1_), + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatB, + FloatB, + decltype(b_block_desc_bk0_bn0_bn1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..33120bd86ff01da47f7b80c593bc966785cb3711 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_contraction_dlops.hpp" + +namespace ck { + +// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. AKMBlockDesc is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BKNBlockDesc is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CM0M1N0N1ThreadDesc is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) +template < + index_t BlockSize, + typename FloatA, + typename FloatB, + typename FloatC, + typename AKMBlockDesc, + typename BKNBlockDesc, + index_t M1PerThreadM11, + index_t N1PerThreadN11, + index_t KPerThread, + index_t M1N1ThreadClusterM100, + index_t M1N1ThreadClusterN100, + index_t M1N1ThreadClusterM101, + index_t M1N1ThreadClusterN101, + index_t AThreadCopyScalarPerVector_M11, + index_t BThreadCopyScalarPerVector_N11, + typename enable_if::type = false> +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t K = AKMBlockDesc{}.GetLength(I0); + static constexpr index_t M = AKMBlockDesc{}.GetLength(I1); + static constexpr index_t N = BKNBlockDesc{}.GetLength(I1); + + static constexpr index_t M100 = M1N1ThreadClusterM100; + static constexpr index_t N100 = M1N1ThreadClusterN100; + + static constexpr index_t M101 = M1N1ThreadClusterM101; + static constexpr index_t N101 = M1N1ThreadClusterN101; + + static constexpr index_t M11 = M1PerThreadM11; + static constexpr index_t N11 = N1PerThreadN11; + + static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11; + static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11; + + static constexpr index_t M0 = M / M1; + static constexpr index_t N0 = N / N1; + + __host__ __device__ static constexpr auto + MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */) + { + const auto a_k_m0_m1_block_desc = transform_tensor_descriptor( + AKMBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_block_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */) + { + const auto b_k_n0_n1_block_desc = transform_tensor_descriptor( + BKNBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_block_desc; + } + + __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M, N] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor; + } + + __host__ __device__ static constexpr auto + MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M0, M1, N0, N1] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor; + } + + __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths() + { + return Sequence{}; + } + + static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{}); + static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{}); + + public: + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2() + : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} + { + static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == M101 * M100 * N101 * N100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(M % M1 == 0 && N % N1 == 0, "wrong!"); + + static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id) + { + // lower: [M0, M1, N0, N1] + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor(); + + // lower: [M0, M100, M101, M11, N0, N100, N101, N11] + // upper: [Tid, M0, M11, N0, N11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)), + make_pass_through_transform(M0), + make_pass_through_transform(M11), + make_pass_through_transform(N0), + make_pass_through_transform(N11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; } + + __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; } + + template + __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 && + CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_k_m0_m1_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_k_n0_n1_thread_desc_.GetElementSpaceSize()); + + constexpr auto threadwise_gemm = + ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1, + Sequence<1, M1PerThreadM11>, + Sequence<1, N1PerThreadN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of k + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[K, M0, M1] + static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // B[K, N0, N1] + static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + AThreadCopyScalarPerVector_M11, + 1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N11, + 1>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f45655721fe4de11b50b91e0dc5d22790ff73bea --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0); + static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1); + static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2); + + static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); + static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); + + static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0); + static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2); + static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3); + + static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + static constexpr auto b_thread_mtx_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number<1>{}, + Number{}, + Number{}, + Number{})); + + static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() + : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)} + { + static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() && + BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && + CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert( + ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) && + ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4), + "wrong! E dimension not consistent\n"); + + static_assert(E1 % EPerThreadLoop == 0, ""); + static_assert(KPerThread % KPerThreadLoop == 0, ""); + + static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 && + WoPerBlock % WoPerThread == 0, + "wrong! Cannot evenly divide work among\n"); + + constexpr auto KThreadCluster = KPerBlock / KPerThread; + constexpr auto HThreadCluster = HoPerBlock / HoPerThread; + constexpr auto WThreadCluster = WoPerBlock / WoPerThread; + + static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, + "wrong! wrong blocksize\n"); + } + + __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths() + { + return Sequence{}; + } + + __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id) + { + constexpr auto K0 = KPerBlock / KPerThread; + constexpr auto N0 = I1; + constexpr auto H0 = HoPerBlock / HoPerThread; + constexpr auto W0 = WoPerBlock / WoPerThread; + + constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_k_n_h_w_thread_cluster_idx = + c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex( + make_multi_index(thread_id)); + + return c_k_n_h_w_thread_cluster_idx; + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BThreadBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{}; + + // thread A buffer for GEMM + StaticBuffer + a_thread_buf; + + constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; + + static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) { + static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) { + a_thread_copy_.Run(a_block_mtx, + make_tuple(e_begin, k_begin, I0), + a_block_buf, + a_thread_mtx_, + make_tuple(I0, I0, I0), + a_thread_buf); + + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(e_begin, I0, I0, I0, I0), + c_thread_buf, + make_tuple(k_begin, I0, I0, I0)); + }); + }); + } + + template + __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) + { + a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx); + } + + private: + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + E2, + E2>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d0e3be41bb6e0450ecc6e7daeef711f0a9cdf8d2 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -0,0 +1,1000 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct LoopScheduler +{ + Default, + Interwave, +}; + +constexpr LoopScheduler make_default_loop_scheduler() +{ +#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING + return LoopScheduler::Interwave; +#else + return LoopScheduler::Default; +#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING +} + +template +__host__ __device__ static constexpr auto +MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) +{ + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); +} + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple(Number{}, + Number{}, + waveId_m, + waveId_n, + blk_idx[I0], + blk_idx[I1], + blk_idx[I2], + blk_idx[I3]); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro +// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in +// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the +// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 +template +struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + +#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING + using Base::a_block_desc_m0_m1_m2_k; + using Base::A_K1; + using Base::b_block_desc_n0_n1_n2_k; + using Base::B_K1; + using Base::c_thread_buf_; + using Base::c_thread_desc_; + using Base::CalculateAThreadOriginDataIndex; + using Base::CalculateBThreadOriginDataIndex; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + + // 2-wave optimized blockwise gemm + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, k), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, k), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + __builtin_amdgcn_s_barrier(); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except + // the first, as we can shorten non-MAC cluster a bit and there's no observable negative + // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids + // some out-of-sync waves hijacking MAC resource from other workgroups and reducing the + // chance of latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. + if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) + { + asm volatile("s_barrier" ::); + __builtin_amdgcn_s_barrier(); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from blockwise_gemm is + // moved here B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts It is performed near the end of MAC cluster to + // minimize lgkmcnt penalty + if constexpr(k.value == KPerThread - KPerInnerLoop && + k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && + n0.value == NRepeat - 1) + { + __builtin_amdgcn_s_barrier(); + block_sync_lds(); + __builtin_amdgcn_s_barrier(); + } + + // TODO: insert setprio in more precise manner since we + // could have more than >1 MFMA instructions in single call + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_s_barrier(); + //__builtin_amdgcn_s_setprio(1); + asm("S_SETPRIO 1"); + __builtin_amdgcn_s_barrier(); + } + }); + }); + }); + __builtin_amdgcn_s_barrier(); + //__builtin_amdgcn_s_setprio(0); + asm("S_SETPRIO 0"); + __builtin_amdgcn_s_barrier(); + }); + } + + protected: + // A[M0, M1, M2, KPerInnerLoop] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + // B[N0, N1, N2, KPerInnerLoop] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + +#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING +}; + +template +constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } +}; + +// Blockwise gemm supporting +// 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4 +// 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS +// source buffer +// 3. configurable k index starting position and step size after each FMA/XDL instruction +template {}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> +struct BlockwiseGemmXdlops_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other) + : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin) + { + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aa814ab00939d0c36866af01b6cf2057dcbd5121 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -0,0 +1,321 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t KPerBlock = K0PerBlock * KPack; + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ void MoveABlockSliceWindow() + { + a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k, + make_multi_index(0, 0, 0, K0PerBlock * KPack)); + } + __device__ void ResetABlockStartWindow() + { + a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex()); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + constexpr index_t k0 = k / KPack; + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + private: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, // KPerThread + Number{}, // repeat + Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d7ec177365abf88b6f5ee7218f3cb0b5d5ffab0b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" + +namespace ck { + +template +struct BlockwiseSoftmax +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0); + static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1); + + using ThreadSliceDesc_M = decltype( + make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); + + using ThreadwiseMaxReduce = typename conditional< + IgnoreNaN, + ThreadwiseReduction>, + ThreadwiseReduction>::type; + + using ThreadwiseSumReduce = typename conditional< + IgnoreNaN, + ThreadwiseReduction>, + ThreadwiseReduction>::type; + + using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths()); + + using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2; + + using BufferType = StaticBuffer; + + template + __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf) + { + // find max value + static_for<0, MRepeat, 1>{}([&](auto I) { + max_value_buf(I) = reduce::Max::template GetIdentityValue(); + }); + ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); + static_for<0, MRepeat, 1>{}([&](auto I) { + BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); + block_sync_lds(); + }); + + // calculate exp for elements, P=exp(s-max) + static_for<0, MRepeat, 1>{}([&](auto iM) { + static_for<0, KRepeat, 1>{}([&](auto iK) { + auto offset = Number{}; + in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset]) + ? 0 + : math::exp(in_thread_buf[offset] - max_value_buf(iM)); + }); + }); + + // sum data + static_for<0, MRepeat, 1>{}([&](auto I) { + sum_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf); + static_for<0, MRepeat, 1>{}([&](auto I) { + BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I)); + block_sync_lds(); + }); + } + + BufferType max_value_buf; + BufferType sum_value_buf; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03e4d42d3a1f9eb1b180c95368b905e619e67110 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseTensorSliceTransfer_v5r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, src_move_slice_window_step_hack); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v5r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp new file mode 100644 index 0000000000000000000000000000000000000000..316508651e4bd4dad2e47edf020a4145176fa753 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/reduction_common.hpp" + +namespace ck { + +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace +// 2) work_buffer has T elements, and space size is no less than 3*BlockSize +// 3) mean_value, var_value and count is the input data in vgpr from each thread +// 4) mean_value, var_value and count is the over-written reduced output in vgpr for each thread +// 5) Merge mean and M from ThreadwiseWelford +// clang-format on +template +struct BlockwiseWelford +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + __device__ static inline void + Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + { + int count = count_a + count_b; + T count_b_over_count = count == 0 ? type_convert(0) : type_convert(count_b) / count; + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a * count_b_over_count; + count_a = count; + } + + __device__ static void Run(T& mean_value, T& var_value, int& count) + { + __shared__ T mean_block_buf[BlockSize]; + __shared__ T var_block_buf[BlockSize]; + __shared__ int count_block_buf[BlockSize]; + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + + mean_block_buf[offset1] = mean_value; + var_block_buf[offset1] = var_value; + count_block_buf[offset1] = count; + + block_sync_lds(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + T mean1 = mean_block_buf[offset1]; + T var1 = var_block_buf[offset1]; + int count1 = count_block_buf[offset1]; + + T mean2 = mean_block_buf[offset2]; + T var2 = var_block_buf[offset2]; + int count2 = count_block_buf[offset2]; + + Merge(mean1, var1, count1, mean2, var2, count2); + + mean_block_buf[offset1] = mean1; + var_block_buf[offset1] = var1; + count_block_buf[offset1] = count1; + } + + block_sync_lds(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + count = count_block_buf[offset]; + mean_value = mean_block_buf[offset]; + + if constexpr(GetActualVariance) + var_value = var_block_buf[offset] / count; + else + var_value = var_block_buf[offset]; + }; +}; +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2163ad32383d4f940ab4cf3b041ab2d5c61456cc --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize +// 3) in_out_value is the input data in vgpr from each thread +// 4) in_out_value is the over-written reduced output in vgpr for each thread +// clang-format on +template > +struct PartitionedBlockwiseReduction +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + template + __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_buffer[offset1]; + AccDataType opData2 = work_buffer[offset2]; + Accumulation::Calculate(opData1, opData2); + work_buffer(offset1) = opData1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_buffer[offset]; + }; +}; + +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize +// 3) in_out_value is the input data in vgpr from each thread +// 4) in_out_value is the over-written reduced output in vgpr for each thread +// clang-format on +template > +struct PartitionedBlockwiseReduction_v2 +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = ThreadClusterDesc{}; + + template + __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_buffer[offset1]; + AccDataType opData2 = work_buffer[offset2]; + Accumulation::Calculate(opData1, opData2); + work_buffer(offset1) = opData1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_buffer[offset]; + }; +}; + +// clang-format off +// Assume: +// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and space size is no less than BlockSize +// 3) in_out_value/in_out_index is the input data in vgpr from each thread +// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread +// clang-format on +template < + typename AccDataType, + typename IndexDataType, + index_t BlockSize, + typename ThreadClusterLengths_M_K, + typename ThreadClusterArrangeOrder, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> +struct PartitionedBlockwiseReductionWithIndex +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + // This interface accumulates on both data values and indices + template + __device__ static void Reduce(BufferType& work_val_buffer, + IdxBufferType& work_idx_buffer, + AccDataType& in_out_value, + IndexDataType& in_out_index) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + static_assert(is_same{}, + "Buffer data type should be consistent as IndexDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << I(); + + if(thread_k_cluster_id % (indOffset * 2) == 0) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_val_buffer[offset1]; + AccDataType opData2 = work_val_buffer[offset2]; + IndexDataType currIndex1 = work_idx_buffer[offset1]; + IndexDataType currIndex2 = work_idx_buffer[offset2]; + + Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2); + work_val_buffer(offset1) = opData1; + work_idx_buffer(offset1) = currIndex1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_val_buffer[offset]; + in_out_index = work_idx_buffer[offset]; + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0e5dfb355fb00565f8d9834a2e0df05bfd9db416 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v4r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5c47a49b38b6a30eebc9190cc338d4dbc0bc8524 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aa33fc083f15cfd8904c8fa086a866dbc7817e7c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. It does not keep reference to tensor descriptor +// 3. Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eb5f589a4ada976c9be4a5001fb6fc288c7e9a43 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + src2_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc2SliceOrigin( + src2_desc, src2_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run( + src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3bd7806389b59542b191e61f5369aebf7300fc6a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs); + } + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a4a29f5d5edc3b840fc7489d5f706231990cad70 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardDataSpecialization +{ + Default, + Filter1x1Stride1Pad0, +}; + +inline std::string +getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization& s) +{ + switch(s) + { + case ConvolutionBackwardDataSpecialization::Default: return "Default"; + case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: + return "FFilter1x1Stride1Pad0"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..20b2a152b9d276b0bd0f6e1aa5798b4817b013a9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardWeightSpecialization +{ + Default, + Filter1x1Stride1Pad0, + Filter1x1Pad0, + OddC, +}; + +inline std::string +getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization& s) +{ + switch(s) + { + case ConvolutionBackwardWeightSpecialization::Default: return "Default"; + case ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0: + return "Filter1x1Stride1Pad0"; + case ConvolutionBackwardWeightSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionBackwardWeightSpecialization::OddC: return "OddC"; + default: return "Unrecognized specialization!"; + } +} +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..953ff1e06ed07b24968c7f5c0842161ac66643ed --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionForwardSpecialization +{ + Default, + Filter1x1Pad0, + Filter1x1Stride1Pad0, + OddC, +}; + +inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) +{ + switch(s) + { + case ConvolutionForwardSpecialization::Default: return "Default"; + case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case ConvolutionForwardSpecialization::OddC: return "OddC"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_base.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_base.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8137098eac6a10763e2466e6dc4be7cc6dbf4c4c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/stream_config.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct BaseArgument +{ + BaseArgument() = default; + BaseArgument(const BaseArgument&) = default; + BaseArgument& operator=(const BaseArgument&) = default; + + virtual ~BaseArgument() {} + + void* p_workspace_ = nullptr; +}; + +struct BaseInvoker +{ + BaseInvoker() = default; + BaseInvoker(const BaseInvoker&) = default; + BaseInvoker& operator=(const BaseInvoker&) = default; + + virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}) + { + return float{0}; + } + + virtual ~BaseInvoker() {} +}; + +struct BaseOperator +{ + BaseOperator() = default; + BaseOperator(const BaseOperator&) = default; + BaseOperator& operator=(const BaseOperator&) = default; + + virtual bool IsSupportedArgument(const BaseArgument*) { return false; } + virtual std::string GetTypeString() const { return ""; } + + virtual std::string GetTypeIdName() const { return typeid(*this).name(); } + + virtual std::string GetTypeIdHashCode() const + { + std::ostringstream oss; + + oss << std::hex << typeid(*this).hash_code(); + + return oss.str(); + }; + + virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } + + virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const + { + assert(p_arg); + p_arg->p_workspace_ = p_workspace; + } + + virtual ~BaseOperator() {} +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9fcd893c7a8c38d31f480b13eb8e7bfc97ec7fef --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceBatchedContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e755913280f73e325daf1352324bab8d3df8a3a9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB, + ck::index_t BatchStrideC, + ck::index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchedGemmPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..acd779b2db34b98c389727b7a369f3808f5c6c1d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp @@ -0,0 +1,50 @@ +#pragma once +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct BatchedGemmEPermuteDesc +{ + ck::index_t G0_, G1_, M_, N_; + ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_; +}; + +template +struct DeviceBatchedGemmEPermute : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..af681127f30404f182475039add4443e46213398 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + ck::index_t StrideB1, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + ck::index_t BatchStrideB1, + ck::index_t BatchStrideC, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..116e62c00907e0f72031552f46b8a3ea975a18f5 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmMultiD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor"); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eacc5976d3ea3885a8294f1ee74a249f0a3d52a7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator +{ + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a0, + const void* p_b0, + std::array p_d0s, + const void* p_b1, + std::array p_d1s, + void* p_e1, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA0, + ck::index_t StrideB0, + std::array StrideD0s, + ck::index_t StrideB1, + std::array StrideD1s, + ck::index_t StrideE1, + ck::index_t BatchStrideA0, + ck::index_t BatchStrideB0, + std::array BatchStrideD0s, + ck::index_t BatchStrideB1, + std::array BatchStrideD1s, + ck::index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c1f85e575ce4b2e0d70c05010bfca6e9ed9b84a8 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template // TODO: enum for mask type +struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + ck::index_t StrideB1, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + ck::index_t BatchStrideB1, + ck::index_t BatchStrideC, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff5551998015d57f1386a263308606768af2a3d9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator +{ + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d39f3b7cbcfc9ccaccca5f755be747cc4896dc1b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormBwd : public BaseOperator +{ + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + virtual std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array dyStrides, + const std::array dxStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_dy, + const void* p_scale, + const void* p_savedMean, + const void* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + void* p_dx, + void* p_dscale, + void* p_dbias) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchNormBwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aa93dd9c19d77fbee7ccdafd1e41e4c4f64735cd --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwd : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* bnScale, + const void* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double exponentialAverageFactor, + void* resultRunningMean, + void* resultRunningVariance) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchNormFwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8a00fd9db3327a7c536b1ada0ca090bfa847a256 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormInfer : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* bnScale, + const void* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + const void* estimatedMean, + const void* estimatedInvVariance, + void* p_y) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchNormInferPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_cgemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aedae53800b03cb78e6af5b15c8c6dfcd9792eea --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceCGemm : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a_real, + const void* p_a_imag, + const void* p_b_real, + const void* p_b_imag, + void* p_c_real, + void* p_c_imag, + void* p_workspace, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::size_t GetWorkspaceSize(index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC) = 0; +}; + +template +using DeviceCGemmPtr = std::unique_ptr< + DeviceCGemm>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dbc525c099bca293d106ac6d3c37ac72b2d749cc --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..82054a3c9423f819d17a003b8c22844f0a2273fe --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(void* p_in, + const void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b9881088dde07a62e92c5f77a19edc9c4b7670f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a627deeb2221f5e271532d45bc8a544c8cceeba --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivation : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cc139303c929ae999d182ad09e9a4ef33f5209e1 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivationAdd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + const void* p_resi, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationAddPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..23aada0f448c81c83ea2d06071161ff81f3b4206 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp @@ -0,0 +1,341 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceElementwise : public DeviceElementwiseBase +{ + static constexpr index_t NumDim = NumDim_m + NumDim_n; + + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + template + static auto PadDescriptor_MN_2d(Desc_MN desc_mn, + index_t gridSize, + index_t blockSize, + index_t num_threads_m, + index_t num_threads_n) + { + std::ignore = blockSize; + std::ignore = gridSize; + const auto m = desc_mn.GetLength(I0); + const auto n = desc_mn.GetLength(I1); + const index_t loop_step_m = num_threads_m * MPerThread; + const index_t loop_step_n = num_threads_n * NPerThread; + const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m; + const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n; + + const auto desc_mn_pad = transform_tensor_descriptor( + desc_mn, + make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return desc_mn_pad; + } + + static auto MakeDescriptor_MN(const std::array& lengths, + const std::array& stride, + index_t gridSize, + index_t blockSize, + index_t num_threads_m, + index_t num_threads_n) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type(); + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type(); + + const auto mLengths = get_container_subset(tupleOfShape, mDimIds); + const auto nLengths = get_container_subset(tupleOfShape, nDimIds); + + // merge nd to 2d desc - [s0 * s1 * ...] + + if constexpr(NumDim > 2) + { + const auto desc_mn = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n); + } + else + return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n); + } + + template + static auto GenerateInOutGrid2dDescTuple(Number) + { + return generate_tuple( + [&](auto) { + if constexpr(NumDim > 2) + { + return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1); + } + else + { + return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1); + }; + }, + Number{}); + }; + + using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number{})); + using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number{})); + + using GridwiseElementwise = GridwiseElementwise_2D; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op), + blockSize_(256), + gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future + num_threads_m_((gridSize_ * blockSize_) / 16), + num_threads_n_(16) + { + static_assert(NumDim_m > 0, ""); + static_assert(NumDim_n > 0, ""); + + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + + in_grid_2d_desc_tuple_ = generate_tuple( + [&](auto I) { + return MakeDescriptor_MN(lengths, + inStridesArray[I.value], + gridSize_, + blockSize_, + num_threads_m_, + num_threads_n_); + }, + Number{}); + + out_grid_2d_desc_tuple_ = generate_tuple( + [&](auto I) { + return MakeDescriptor_MN(lengths, + outStridesArray[I.value], + gridSize_, + blockSize_, + num_threads_m_, + num_threads_n_); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + InGrid2dDescTuple in_grid_2d_desc_tuple_; + OutGrid2dDescTuple out_grid_2d_desc_tuple_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + index_t blockSize_; + index_t gridSize_; + index_t num_threads_m_; + index_t num_threads_n_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_elementwise_2d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.in_grid_2d_desc_tuple_, + arg.out_grid_2d_desc_tuple_, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + arg.elementwise_op_, + arg.num_threads_m_, + arg.num_threads_n_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector, + index_t vectorDim) { + if(strides[vectorDim] == 1 && + (lengths[vectorDim] % scalarPerVector == 0 || + lengths[vectorDim] % scalarPerVector == lengths[vectorDim])) + { + return true; + } + if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim]) + { + return true; + } + return false; + }; + + bool valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid(pArg->lengths_, + pArg->inStridesArray_[I.value], + InScalarPerVectorSeq::At(I), + NumDim_m - 1)) + valid = false; + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid(pArg->lengths_, + pArg->outStridesArray_[I.value], + OutScalarPerVectorSeq::At(I), + NumDim - 1)) + valid = false; + }); + + return valid; + }; + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_base.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_base.hpp new file mode 100644 index 0000000000000000000000000000000000000000..728faf543df6cdd5a6593224ee0bb1801b22cd06 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_base.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseBase : public BaseOperator +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; // namespace device + +template +using DeviceElementwiseBasePtr = std::unique_ptr< + DeviceElementwiseBase>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..42f821ff2ebfd3bfc9940e4089aaca93790e7a29 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseNormalization : public BaseOperator +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::array, NumInput> inStridesArray, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + AccDataType epsilon, + const std::array in_dev_buffers, + const void* p_gamma, + const void* p_beta, + void* p_y, + XElementwiseOperation x_elementwise_op, + YElementwiseOperation y_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceElementwiseNormalizationPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c0af6f80faf606a22678e516be994a75c1d56eca --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4c2161eaed5b9c7d7913686f44c360c044ca3b07 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DEGridDesc_M0_M1_M2_N0_N1 +{ + ck::index_t M0_, M1_, M2_, N0_, N1_; + ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_; +}; + +// input : A[M, K], B[K, N], +// input : D[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D) +template +struct DeviceGemmBiasCPermute : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9113bb7b7454443cce4871bf263433d4f907bfba --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f4881e32f620be9c31649085813a257e9e84a598 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// FIXME: DeviceGemmReduce type need to well define the problem +// GEMM: +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : R0[M], R1[M], ... +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ... +// R0 = r_op0(Q0), R1 = r_op1(Q1), ... +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDMultipleR : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::array p_rs, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmMultipleDMultipleRPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcc088ca43d1b8c6224c0015e3eb7434038af4b2 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// FIXME: DeviceGemmReduce type need to well define the problem +template +struct DeviceGemmReduce : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_ops, + std::array reduce_out_element_ops, + ck::index_t BatchCount = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmReducePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c701bff57f8eb7db155e9abdfd3ab7210e7eeffd --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmSplitK : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmSplitKPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..173c613a325d8c594298ff6751060fecf2f8d453 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct ContractionDesc +{ + std::vector a_ms_ks_lengths; + std::vector a_ms_ks_strides; + + std::vector b_ns_ks_lengths; + std::vector b_ns_ks_strides; + + std::array, NumDTensor> ds_ms_ns_lengths; + std::array, NumDTensor> ds_ms_ns_strides; + + std::vector e_ms_ns_lengths; + std::vector e_ms_ns_strides; +}; + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceGroupedContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3350aec8d3174010cf0d926c3a91802ce3b8e7e7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A[G, N, K, Ho, Wo] +// input : weight B[G, K, C, Y, X], +// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... +// output : input image E[G, N, C, Hi, Wi], +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1258aed71c502c659f065d08e8b1a84922c26c83 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedConvBwdWeight : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_wei, + const void* p_out, + ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..644c7ee9a9107718dc609836e6cd8abe7f2dad21 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Convolution Forward: +// input : input image A[G, N, C, Hi, Wi], +// input : weight B[G, K, C, Y, X], +// output : output image E[G, N, K, Ho, Wo] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + const void* p_wei, // weight + void* p_out, // output image + const std::array& in_g_n_c_wis_lengths, + const std::array& in_g_n_c_wis_strides, + const std::array& wei_g_k_c_xs_lengths, + const std::array& wei_g_k_c_xs_strides, + const std::array& out_g_n_k_wos_lengths, + const std::array& out_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const InElementwiseOperation& in_element_op, + const WeiElementwiseOperation& wei_element_op, + const OutElementwiseOperation& out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3bb085c91760f2633b7e4f8e9d1d81644e3d1688 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,959 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; +}; + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx1030__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); + + __shared__ ABDataType p_shared[shared_block_size]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = batch_count; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = e_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetEPtrOffset(0); +#endif +} +} // namespace + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template +struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + : public DeviceGroupedConvFwdMultipleD +{ + using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_AK0_M_AK1 = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + ds_grid_desc_m0_m10_m11_n0_n10_n11_{}, + e_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( + ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity( + a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, e_grid_desc_m_n_)) + { + + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_ak0_m_ak1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_bk0_n_bk1_); + e_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_); + + ds_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + std::cout << "num_group: " << num_group_ << std::endl; + + std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl; + std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl; + std::cout << "A[m0, m10, m11, n0, n10, n11]: " << e_grid_desc_m0_m10_m11_n0_n10_n11_ + << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // block-to-e-tile map + DefaultBlock2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0), + arg.e_grid_desc_m_n_.GetLength(I1)) * + arg.num_group_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop; + + const auto kernel = kernel_grouped_conv_fwd_dl_multiple_d< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_K0_M0_M1_K1, + DeviceOp::BGridDesc_K0_N0_N1_K1, + DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11, + DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11, + DefaultBlock2CTileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + has_double_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.e_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + return 0; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030")) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X + << " ConvStride = " << ConvStride << " LeftPad = " << LeftPad + << " RightPad = " << RightPad << std::endl; + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X + << " LeftPad = " << LeftPad << " RightPad = " << RightPad + << std::endl; + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5)) + { + return false; + } + } + else + { + return false; + } + // check Gridwise GEMM + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_m_n_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..acdd381b63191f670c9cf84197f7385fb4522c9b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,837 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; +}; + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_dl( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx1030__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); + + __shared__ ABDataType p_shared[shared_block_size]; + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif +} + +} // namespace + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template < + index_t NDimSpatial, + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd +{ + using DeviceOp = DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + c_g_n_k_wos_lengths, + c_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeCGridDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(c_g_n_k_wos_lengths, + c_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + using AGridDesc_AK0_M_AK1 = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t({}, {}))>; + using CGridDesc_M_N = remove_cvref_t({}, {}))>; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + void* p_c, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_c_grid_{static_cast(p_c)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + c_g_n_k_wos_lengths, + c_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_g_n_k_wos_lengths, + c_g_n_k_wos_strides)}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{ + a_g_n_c_wis_strides[0], b_g_k_c_xs_strides[0], c_g_n_k_wos_strides[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + c_g_n_k_wos_lengths_{c_g_n_k_wos_lengths}, + c_g_n_k_wos_strides_{c_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = c_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity( + a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) + { + + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_ak0_m_ak1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_bk0_n_bk1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl; + std::cout << "num_group: " << num_group_ << std::endl; + + std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl; + std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl; + std::cout << "A[m0, m10, m11, n0, n10, n11]: " << c_grid_desc_m0_m10_m11_n0_n10_n11_ + << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + // block-to-e-tile map + DefaultBlock2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array c_g_n_k_wos_lengths_; + std::array c_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + // if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_.GetLength(I0), + arg.c_grid_desc_m_n_.GetLength(I1)) * + arg.num_group_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop; + + const auto kernel = + kernel_grouped_conv_fwd_dl; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(!(ck::get_device_name() == "gfx906" ||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030")) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X + << " ConvStride = " << ConvStride << " LeftPad = " << LeftPad + << " RightPad = " << RightPad << std::endl; + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X + << " LeftPad = " << LeftPad << " RightPad = " << RightPad + << std::endl; + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of C + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.c_g_n_k_wos_lengths_[2]; + + if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5)) + { + return false; + } + } + else + { + return false; + } + // check Gridwise GEMM + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* p_c, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) + { + return Argument{p_a, + p_b, + p_c, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + c_g_n_k_wos_lengths, + c_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) override + { + return std::make_unique(p_a, + p_b, + p_c, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + c_g_n_k_wos_lengths, + c_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + c_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1e2f81915d92e26e96290f71b9b73c3465ae8130 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Convolution Forward: +// input : input image A[G, N, C, Hi, Wi], +// input : weight B[G, K, C, Y, X], +// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... +// output : output image E[G, N, K, Ho, Wo] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvFwdMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, // input image + const void* p_b, // weight + const std::array& p_ds, + void* p_e, // output image + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..181ee4b428ba236d92cf71bbf33b24e53b2491c0 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -0,0 +1,51 @@ +#pragma once +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct GemmDesc +{ + ck::index_t M_, N_, K_; + ck::index_t stride_A_, stride_B_, stride_C_; + + std::vector stride_Ds_; +}; + +template +struct DeviceGroupedGemm : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor"); + + virtual std::unique_ptr + MakeArgumentPointer(std::vector& p_a, + std::vector& p_b, + std::vector>& p_ds, + std::vector& p_e, + std::vector& gemm_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b066a4458518b378313225e24817253922962d29 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator +{ + struct ProblemDesc + { + std::vector a_gs_ms_ks_lengths; + std::vector a_gs_ms_ks_strides; + + std::vector b0_gs_ns_ks_lengths; + std::vector b0_gs_ns_ks_strides; + + std::vector b1_gs_os_ns_lengths; + std::vector b1_gs_os_ns_strides; + + std::vector c_gs_ms_os_lengths; + std::vector c_gs_ms_os_strides; + + std::vector> acc0_biases_gs_ms_ns_lengths; + std::vector> acc0_biases_gs_ms_ns_strides; + + std::vector> acc1_biases_gs_ms_os_lengths; + std::vector> acc1_biases_gs_ms_os_strides; + }; + + virtual std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b0_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..946a757ceeeb39040951e16f7c25540a2932a4df --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -0,0 +1,881 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( + const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(group_kernel_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while( + (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_))) + { + if(block_id < arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + // per-group batch offset + const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; + const index_t g_idx = __builtin_amdgcn_readfirstlane( + (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run( + arg_ptr[group_id].p_a_grid_ + a_batch_offset, + arg_ptr[group_id].p_b_grid_ + b_batch_offset, + arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, + arg_ptr[group_id].p_c_grid_ + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg_ptr[group_id].block_2_ctile_map_, + arg_ptr[group_id].c0_matrix_mask_); +#else + ignore = group_kernel_args; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle + : public DeviceGroupedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + +#if 0 + // TODO ANT: use alias + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimN; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimO; + static constexpr index_t NumDimGemm1K = NumDimN; +#endif + + using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; + using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute::ProblemDesc; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< + Sequence, + Sequence, + GemmSpec, + ASpec, + BSpec, + B1Spec, + CSpec>; + + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), + Number{}); + } + + static auto + MakeB1GridDescriptor_BK0_N_BK1(const std::vector& b1_gs_gemm1ns_gemm1ks_lengths_vec, + const std::vector& b1_gs_gemm1ns_gemm1ks_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, + b1_gs_gemm1ns_gemm1ks_strides_vec), + Number{}); + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const BGridDesc_G_N_K& b_grid_desc_g_n_k, + const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b_grid_desc_g_n_k_(b_grid_desc_g_n_k), + b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; + + using Block2CTileMap = OffsettedBlockToCTileMap; + + struct GroupKernelArg + { + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // batch & stride + index_t num_blocks_per_batch_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // block-to-c-tile map + Block2CTileMap block_2_ctile_map_; + + index_t block_start_, block_end_; + }; + + struct GroupDeviceArg + { + // lengths for the last dimensions of overall problem for sanity check of vector load/store + std::vector raw_lengths_mz_nz_kz_gemm1nz_; + + // strides for the last dimensions of each tensor for sanity check of vector load/store + std::vector a_mz_kz_strides_; + std::vector b_nz_kz_strides_; + std::vector b1_nz_kz_strides_; + std::vector c_mz_gemm1nz_strides_; + + // for gridwise gemm check + CGridDesc_M_N c_grid_desc_m_n_; + }; + + // Argument + // FIXME: constness + struct Argument : public BaseArgument + { + Argument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op} + { + // TODO ANT: implement bias addition + group_count_ = problem_desc_vec.size(); + + if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && + group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size())) + { + throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); + } + + if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size())) + { + throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size"); + } + + grid_size_ = 0; + + for(std::size_t i = 0; i < group_count_; i++) + { + const auto p_a_grid = static_cast(p_a_vec[i]); + const auto p_b_grid = static_cast(p_b_vec[i]); + const auto p_b1_grid = static_cast(p_b1_vec[i]); + const auto p_c_grid = static_cast(p_c_vec[i]); + + const auto& problem_desc = problem_desc_vec[i]; + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); + const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1( + problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N( + problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); + + const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K( + problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); + const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K( + problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( + problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( + problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart); + const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0); + const index_t grid_size_grp = + block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + // batch stride + const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( + a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n); + + // C0 mask + const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); + + grid_size_ += grid_size_grp; + + // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and + // so on + if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias && + problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias && + problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias && + problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias)) + { + throw std::runtime_error( + "wrong! number of biases in function argument does not " + "match that in template argument"); + } + + group_kernel_args_.push_back({p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n), + compute_base_ptr_of_batch, + c0_matrix_mask, + block_2_ctile_map, + BlockStart, + BlockEnd}); + + group_device_args_.push_back( + {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1], + problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], + problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]}, + {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1], + problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, + {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1], + problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]}, + {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1], + problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]}, + c_grid_desc_m_n}); + } + } + + std::vector group_kernel_args_; + std::vector group_device_args_; + + std::size_t group_count_; + index_t grid_size_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error("wrong! unsupported argument"); + } + + bool all_has_main_k_block_loop = true; + bool some_has_main_k_block_loop = false; + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); + const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K); + all_has_main_k_block_loop &= y; + some_has_main_k_block_loop |= y; + } + + hipGetErrorString(hipMemcpy(arg.p_workspace_, + arg.group_kernel_args_.data(), + arg.group_kernel_args_.size() * sizeof(GroupKernelArg), + hipMemcpyHostToDevice)); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(all_has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else if(!some_has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + throw std::runtime_error("wrong! all gemm problems have to simultaneously meet " + "has_main_k_block_loop or no_main_k_block_loop"); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + // TODO ANT: Check if tensor specialization & strides mismatch + + bool all_has_main_k_block_loop = true; + bool some_has_main_k_block_loop = false; + + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto& kernel_arg = arg.group_kernel_args_[i]; + const auto& device_arg = arg.group_device_args_[i]; + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0); + const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1); + const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); + if(!(c_m == a_m && c_gemm1n == b1_gemm1n)) + { + return false; + } + + // Check if having main loop + const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * + kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K); + all_has_main_k_block_loop &= y; + some_has_main_k_block_loop |= y; + + // Note: we need raw lengths since threadwise copy can not handle vector load when + // part of vector is out of bounds + const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; + const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; + const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2]; + const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; + const auto c_extent_lowest = Gemm1NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 + ? device_arg.a_mz_kz_strides_[1] + : device_arg.a_mz_kz_strides_[0]; + const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2 + ? device_arg.b_nz_kz_strides_[1] + : device_arg.b_nz_kz_strides_[0]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? device_arg.b1_nz_kz_strides_[1] + : device_arg.b1_nz_kz_strides_[0]; + const auto c_stride_lowest = + device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be + // contiguous + + if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_, + kernel_arg.b_grid_desc_bk0_n_bk1_, + kernel_arg.b1_grid_desc_bk0_n_bk1_, + device_arg.c_grid_desc_m_n_, + kernel_arg.block_2_ctile_map_)) + { + return false; + } + } + + // all gemm problems have to simultaneously meet has_main_k_block_loop or + // no_main_k_block_loop + if(!(all_has_main_k_block_loop || !some_has_main_k_block_loop)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a_vec, + p_b_vec, + p_b1_vec, + p_c_vec, + p_acc0_biases_vec, + p_acc1_biases_vec, + problem_desc_vec, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(p_a_vec, + p_b_vec, + p_b1_vec, + p_c_vec, + p_acc0_biases_vec, + p_acc1_biases_vec, + problem_desc_vec, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(BSpec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GroupKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..93202e352e8a8df349eb08f864d160e59520465b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMultipleReduce : public BaseOperator +{ + static constexpr index_t NumInputDim = Rank; + static constexpr index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1; + + virtual std::unique_ptr MakeArgumentPointer( + const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array, NumReduction> outStrides, + const std::array reduceDims, + const std::array alphas, + const std::array betas, + const void* in_dev, + const std::array out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceMultipleReducePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_normalization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..227c352cbd755f9f8017cf7e3231e689447465dd --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalization : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + AccDataType epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_savedMean, + void* p_savedInvVar, + AccElementwiseOperation acc_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_permute.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..baa914477589848a4bc0dbf6bc06c31a9770f0f0 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_permute.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePermute : BaseOperator +{ + using Lengths = std::array; + using Strides = Lengths; + + virtual std::unique_ptr + MakeArgumentPointer(const Lengths& in_lengths, + const Strides& in_strides, + const Lengths& out_lengths, + const Strides& out_strides, + const void* in_dev_buffer, + void* out_dev_buffer, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b376c6f73f066729fa118013970682ca31b340e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool2dFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* in_dev, + void* out_dev, + void* out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DevicePool2dFwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..15aeb8e91cdb562cc57663f6eeecd7a72668e5a2 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduce : public BaseOperator +{ + static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim; + + virtual std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceReducePtr = std::unique_ptr< + DeviceReduce>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_softmax.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..676e0812b740881bdd63d909b4fdf1e19aed88a3 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_softmax.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSoftmax : public BaseOperator +{ + // + // @brief Makes a pointer to Argument class. + // + // @param[in] inLengths Input tensor extent(s) from high to low dimension + // @param[in] inStrides Input tensor stride(s) from high to low dimension + // @param[in] reduceDims The dimension(s) the normalization operation is applied + // @param[in] alpha Typeless pointer in host memory storing the alpha scaling + // value as type AccDataType + // @param[in] beta Typeless pointer in host memory storing the beta scaling + // value as type AccDataType + // @param[in] in_dev Typeless const pointer in device memory storing the input + // tensor + // @param out_dev Typeless pointer in device memory storing the output tensor + // @param[in] in_elementwise_op The input elementwise operation. + // @param[in] acc_elementwise_op The accumulation elementwise operation. + // + // @return Unique pointer to the Argument class. + // + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + const void* alpha, + const void* beta, + const void* in_dev, + void* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual index_t GetRank() const = 0; + virtual index_t GetNumReduceDim() const = 0; +}; + +template +using DeviceSoftmaxPtr = std::unique_ptr< + DeviceSoftmax>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f59e6093e2ae024ccbf8827082c755eba8cee88f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceSplitKContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8eab1cdee565b4334ede5ca9bb7544ea37ff498d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + FloatDsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_akb_ak0_m_ak1; + ignore = b_grid_desc_bkb_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceSplitKContractionMultipleD_Xdl_CShuffle + : public DeviceSplitKContractionMultipleD +{ + using DeviceOp = DeviceSplitKContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(batch_stride_A_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(batch_stride_B_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_M_K, + BGridDesc_N_K, + DsGridDesc_M_N, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_M_K, + BGridDesc_N_K, + DsGridDesc_M_N, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + a_grid_desc_akb_ak0_m_ak1_{GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1( + a_grid_desc_m_k_, split_k)}, + b_grid_desc_bkb_bk0_n_bk1_{GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1( + b_grid_desc_n_k_, split_k)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{ + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{}, + a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]}, + b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]}, + compute_ptr_offset_of_batch_{ + a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}, + split_k_{split_k} + { + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, ""); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i], + ds_gs_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_, + b_grid_desc_bkb_bk0_n_bk1_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; + a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]; + b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1]; + b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]; + + Print(); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_.GetLength(I0) << ", " + << a_grid_desc_m_k_.GetLength(I1) << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_.GetLength(I0) << ", " + << b_grid_desc_n_k_.GetLength(I1) << std::endl; + + std::cout << "A[akb, ak0, m, ak1]: " << a_grid_desc_akb_ak0_m_ak1_.GetLength(I0) << ", " + << a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) << ", " + << a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) << ", " + << a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) << std::endl; + std::cout << "B[bkb, bk0, n, bk1]: " << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I0) << ", " + << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I1) << ", " + << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) << ", " + << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) << std::endl; + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i].GetLength(I0) << ", " + << ds_grid_desc_m_n_[i].GetLength(I1) << std::endl; + }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_.GetLength(I0) << ", " + << e_grid_desc_m_n_.GetLength(I1) << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1_; + BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1_; + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + + index_t a_batch_stride_; + index_t b_batch_stride_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t split_k_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + + const auto K = arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) * + arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AKB_AK0_M_AK1, + DeviceOp::BGridDesc_BKB_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + auto launch_kernel_atomic_add = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemmAtomicAdd::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AKB_AK0_M_AK1, + DeviceOp::BGridDesc_BKB_BK0_N_BK1, + typename GridwiseGemmAtomicAdd:: + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemmAtomicAdd:: + EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap, + has_main_loop>; + + hipGetErrorString(hipMemset( + arg.p_e_grid_, + 0, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(EDataType))); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + if(arg.split_k_ <= 1) + return launch_kernel(integral_constant{}); + else + return launch_kernel_atomic_add(integral_constant{}); + } + else + { + if(arg.split_k_ <= 1) + return launch_kernel(integral_constant{}); + else + return launch_kernel_atomic_add(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 2 || ABlockTransferSrcVectorDim == 3) && + (BBlockTransferSrcVectorDim == 2 || BBlockTransferSrcVectorDim == 3), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 2) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) % ABlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 2) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) % BBlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!((arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceSplitKContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimG << ", " + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fc913e9ba03611dfbabea020bd4054c0b953726d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct GemmSpecialization +{ + // Gemm + Default, + MPadding, + NPadding, + KPadding, + MNPadding, + MKPadding, + NKPadding, + MNKPadding, + // Gemm + Gemm + OPadding, + MOPadding, + NOPadding, + KOPadding, + MNOPadding, + MKOPadding, + NKOPadding, + MNKOPadding, +}; + +inline std::string getGemmSpecializationString(const GemmSpecialization& s) +{ + switch(s) + { + case GemmSpecialization::Default: return "Default"; + case GemmSpecialization::MPadding: return "MPadding"; + case GemmSpecialization::NPadding: return "NPadding"; + case GemmSpecialization::KPadding: return "KPadding"; + case GemmSpecialization::MNPadding: return "MNPadding"; + case GemmSpecialization::MKPadding: return "MKPadding"; + case GemmSpecialization::NKPadding: return "NKPadding"; + case GemmSpecialization::MNKPadding: return "MNKPadding"; + case GemmSpecialization::OPadding: return "OPadding"; + case GemmSpecialization::MOPadding: return "MOPadding"; + case GemmSpecialization::NOPadding: return "NOPadding"; + case GemmSpecialization::KOPadding: return "KOPadding"; + case GemmSpecialization::MNOPadding: return "MNOPadding"; + case GemmSpecialization::MKOPadding: return "MKOPadding"; + case GemmSpecialization::NKOPadding: return "NKOPadding"; + case GemmSpecialization::MNKOPadding: return "MNKOPadding"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2237ad944cd61775267851a9a2629efb115f286f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1040 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + FloatDsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 +// +// Detail- Packed tensor satisfies +// stride_0 = 1 +// stride_i = stride_{i - 1} * extent_{i - 1} +// So tensor +// [G0, G1, G2, M, N] +// transposed into tensor +// [G0, G2, G1, M, N] +// with strides +// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1] +// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some +// strides from input tensor extents so finer dimension information is lost. Merging dimensions is +// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. +// +// Might need to expose dimension order to the interface to fully support +// TensorSpecialization::Packed in a traditional sense of "packed" tensor +template +struct DeviceBatchedContractionMultipleD_Xdl_CShuffle + : public DeviceBatchedContractionMultipleD +{ + using DeviceOp = DeviceBatchedContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_A_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_B_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{}, + a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]}, + b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]}, + compute_ptr_offset_of_batch_{ + a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_} + { + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, ""); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i], + ds_gs_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; + a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]; + b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1]; + b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]; + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + + index_t a_batch_stride_; + index_t b_batch_stride_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!((arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimG << ", " + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..01f5e17d91436747d0b3bf052061741b34e0b46a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -0,0 +1,683 @@ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for +\link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ck::Tuple<>{}, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_etile_map; +#endif +} + +template +struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute +{ + using DeviceOp = DeviceBatchedGemmEPermuteXdl; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N) + { + const auto e_grid_desc_mraw_nraw = + make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N)); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, + index_t G1, + index_t MRaw, + index_t NRaw, + index_t stride_G0, + index_t stride_G1, + index_t stride_M, + index_t stride_N) + { + const auto e_grid_desc_g0_g1_mraw_nraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(G0, G1, MRaw, NRaw), + make_tuple(stride_G0, stride_G1, stride_M, stride_N)); + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_pass_through_transform(MRaw), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else + { + // not pad M or N + return e_grid_desc_g0_g1_mraw_nraw; + } + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1)); + using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, + index_t Batchstride_B, + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n) + : Batchstride_A_(Batchstride_A), + Batchstride_B_(Batchstride_B), + e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_A_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_B_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1); + index_t b0 = g_idx / G1; + index_t b1 = g_idx - b0 * G1; // g_idx % G1 + return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0)); + } + + private: + index_t Batchstride_A_; + index_t Batchstride_B_; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + }; + + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + ck::Tuple<>, // DsDataType, + EDataType, // EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_M_K, + BGridDesc_N_K, + Tuple<>, + EGridDesc_M_N, + NumPrefetch, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_e_grid_{p_e_grid}, + BatchCount_(BatchCount), + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_, + batched_gemm_e_permute_desc.N_, + batched_gemm_e_permute_desc.stride_M_, + batched_gemm_e_permute_desc.stride_N_)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + e_grid_desc_mblock_mperblock_nblock_nperblock{}, + e_grid_desc_g0_g1_m_n_{ + DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_, + batched_gemm_e_permute_desc.G1_, + batched_gemm_e_permute_desc.M_, + batched_gemm_e_permute_desc.N_, + batched_gemm_e_permute_desc.stride_G0_, + batched_gemm_e_permute_desc.stride_G1_, + batched_gemm_e_permute_desc.stride_M_, + batched_gemm_e_permute_desc.stride_N_)}, + compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ck::Tuple<>{}, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // batch count + index_t BatchCount_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + + // for calculating Batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + ck::Tuple<>{}, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_e_permute_xdl< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + EDataType, + remove_reference_t, + remove_reference_t, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_e_grid_, + arg.BatchCount_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + ck::Tuple<>{}, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + EDataType* p_e, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_e, + M, + N, + K, + stride_A, + stride_B, + batch_stride_A, + batch_stride_B, + batched_gemm_e_permute_desc, + BatchCount, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_e), + M, + N, + K, + stride_A, + stride_B, + batch_stride_A, + batch_stride_B, + batched_gemm_e_permute_desc, + BatchCount, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmEPermuteXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b87e56337f95e2463a8211c961504fceb764d8b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -0,0 +1,747 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm +{ + using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 + static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideB1_; + index_t BatchStrideC_; + }; + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, // = ORaw + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + b1_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + batch_count_(Batch), + compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}, + raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + b1_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B0[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl; + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + + // For robust IsSupportedArgument() check + std::vector raw_lengths_m_n_k_o_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error("wrong! unsupported argument"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + ComputeBasePtrOfStridedBatch, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; + const auto NRaw = arg.raw_lengths_m_n_k_o_[1]; + const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; + const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = + is_same_v ? KRaw : MRaw; + const auto b_extent_lowest = + is_same_v ? NRaw : KRaw; + const auto b1_extent_lowest = + is_same_v ? Gemm1NRaw : NRaw; + const auto c_extent_lowest = + is_same_v ? Gemm1NRaw : MRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_b1, p_c, MRaw, + NRaw, KRaw, Gemm1NRaw, Batch, StrideA, + StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB, + BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op, + b1_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_b1), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + Gemm1NRaw, + Batch, + StrideA, + StrideB, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideB1, + BatchStrideC, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e34c19bdf02841d0f9c0b1a381570f7e34120b60 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -0,0 +1,716 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ + +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_etile_map; +#endif +} + +template +struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD +{ + using DeviceOp = DeviceBatchedGemmMultiD_Xdl; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = g_idx * static_cast(BatchStrideDs_[i]); + }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideE_; + }; + + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + Batch_(Batch), + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Batch + index_t Batch_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for calculating batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmMultiD_Xdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.Batch_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = + kernel_batched_gemm_xdl; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultiD_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..19e2649e7ebffe60b7ccba1ea1d9eaecfcc95356 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,951 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_gemm_xdl_cshuffle_v1( + const A0B0B1DataType* __restrict__ p_a0_grid, + const A0B0B1DataType* __restrict__ p_b0_grid, + D0sPointer p_d0s_grid, + const A0B0B1DataType* __restrict__ p_b1_grid, + D1sPointer p_d1s_grid, + E1DataType* __restrict__ p_e1_grid, + const A0ElementwiseOperation a0_element_op, + const B0ElementwiseOperation b0_element_op, + const CDE0ElementwiseOperation cde0_element_op, + const B1ElementwiseOperation b1_element_op, + const CDE1ElementwiseOperation cde1_element_op, + const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, + const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e1_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2E1TileMap block_2_e1tile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); + + static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); + p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset; + }); + + GridwiseGemm::template Run(p_a0_grid + a_batch_offset, + p_b0_grid + b_batch_offset, + p_d0s_grid, + p_b1_grid + b1_batch_offset, + p_d1s_grid, + p_e1_grid + c_batch_offset, + p_shared, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op, + a0_grid_desc_ak0_m_ak1, + b0_grid_desc_bk0_n_bk1, + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + b1_grid_desc_bk0_n_bk1, + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + e1_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_e1tile_map); +#else + ignore = p_a0_grid; + ignore = p_b0_grid; + ignore = p_d0s_grid; + ignore = p_b1_grid; + ignore = p_d1s_grid; + ignore = p_e1_grid; + ignore = a0_element_op; + ignore = b0_element_op; + ignore = cde0_element_op; + ignore = b1_element_op; + ignore = cde1_element_op; + ignore = a0_grid_desc_ak0_m_ak1; + ignore = b0_grid_desc_bk0_n_bk1; + ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_e1tile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; +#endif +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle + : public DeviceBatchedGemmMultipleDGemmMultipleD +{ + using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle; + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + static constexpr auto gemm0_padder = + GemmPadder_v2{ + Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock}; + + static constexpr auto gemm1_padder = + GemmPadder_v2{ + Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock}; + + // for Gemm0 + static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0) + { + const auto a0_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA0, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA0)); + } + }(); + + return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw); + } + + // for Gemm0 + static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b0_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw); + } + + // for Gemm0 + template + static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0) + { + const auto d0_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideD0, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideD0)); + } + }(); + + return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw); + } + + // for Gemm1 + static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw); + } + + // for Gemm1 + template + static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1) + { + const auto e1_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE1, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE1)); + } + }(); + + return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw); + } + + static auto MakeD0sGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeD0GridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + static auto MakeD1sGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeE1GridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1) + : BatchStrideA0_(BatchStrideA0), + BatchStrideB0_(BatchStrideB0), + BatchStrideD0s_(BatchStrideD0s), + BatchStrideB1_(BatchStrideB1), + BatchStrideD1s_(BatchStrideD1s), + BatchStrideE1_(BatchStrideE1) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA0_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD0s_[d1_idx]); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE1_); + } + + template + __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD1s_[d1_idx]); + } + + private: + index_t BatchStrideA0_; + index_t BatchStrideB0_; + std::array BatchStrideD0s_; + index_t BatchStrideB1_; + std::array BatchStrideD1s_; + index_t BatchStrideE1_; + }; + + using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(1, 1, 1)); + using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(1, 1, 1)); + using D0sGridDesc_M_N = remove_cvref_t; + using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(1, 1, 1)); + using D1sGridDesc_M_N = remove_cvref_t; + using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< + A0DataType, // TODO: distinguish A/B datatype + Acc0DataType, + D0sDataType, + Acc1DataType, + C1ShuffleDataType, + D1sDataType, + E1DataType, + A0ElementwiseOperation, + B0ElementwiseOperation, + CDE0ElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + InMemoryDataOperationEnum::Set, + A0GridDesc_M_K, + B0GridDesc_N_K, + D0sGridDesc_M_N, + B1GridDesc_N_K, + D1sGridDesc_M_N, + E1GridDesc_M_N, + NumGemm0KPrefetchStage, + BlockSize, + Gemm0MPerBlock, + Gemm0NPerBlock, + Gemm0KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + A0K1, + B0K1, + B1K1, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm0NXdlPerWave, + Gemm1NXdlPerWave, + A0BlockTransferThreadClusterLengths_AK0_M_AK1, + A0BlockTransferThreadClusterArrangeOrder, + A0BlockTransferSrcAccessOrder, + A0BlockTransferSrcVectorDim, + A0BlockTransferSrcScalarPerVector, + A0BlockTransferDstScalarPerVector_AK1, + true, + A0BlockLdsExtraM, + B0BlockTransferThreadClusterLengths_BK0_N_BK1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_BK1, + true, + B0BlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + C1ShuffleMXdlPerWavePerShuffle, + C1ShuffleGemm0NXdlPerWavePerShuffle, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using A0GridDesc_AK0_M_AK1 = remove_cvref_t; + using B0GridDesc_BK0_N_BK1 = remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const A0DataType* p_a0_grid, + const B0DataType* p_b0_grid, + std::array p_d0s_grid, + const B1DataType* p_b1_grid, + std::array p_d1s_grid, + E1DataType* p_e1_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, // = ORaw + index_t Batch, + index_t StrideA0, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) + : p_a0_grid_{p_a0_grid}, + p_b0_grid_{p_b0_grid}, + p_d0s_grid_{}, + p_b1_grid_{p_b1_grid}, + p_d1s_grid_{}, + p_e1_grid_{p_e1_grid}, + a0_grid_desc_m_k_{DeviceOp::MakeA0GridDescriptor_M_K(MRaw, KRaw, StrideA0)}, + b0_grid_desc_n_k_{DeviceOp::MakeB0GridDescriptor_N_K(KRaw, NRaw, StrideB0)}, + d0s_grid_desc_m_n_{}, + b1_grid_desc_n_k_{DeviceOp::MakeB1GridDescriptor_N_K(NRaw, Gemm1NRaw, StrideB1)}, + d1s_grid_desc_m_n_{}, + e1_grid_desc_m_n_{ + DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideE1)}, + a0_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)}, + b0_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)}, + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, + b1_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)}, + d1s_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)}, + a0_element_op_{a0_element_op}, + b0_element_op_{b0_element_op}, + cde0_element_op_{cde0_element_op}, + b1_element_op_{b1_element_op}, + cde1_element_op_{cde1_element_op}, + batch_count_(Batch), + compute_base_ptr_of_batch_{BatchStrideA0, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1} + { + std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " + << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; + std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", " + << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; + std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", " + << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; + std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " + << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; + std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" + << std::endl; + std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " + << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0Layout = remove_cvref_t>; + using D0DataType = remove_cvref_t>; + + // D0 pointer + p_d0s_grid_(i) = static_cast(p_d0s_grid[i]); + + // D0 desc + d0s_grid_desc_m_n_(i) = + DeviceOp::MakeD0GridDescriptor_M_N(MRaw, NRaw, StrideD0s[i]); + }); + + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + using D1Layout = remove_cvref_t>; + using D1DataType = remove_cvref_t>; + + // D1 pointer + p_d1s_grid_(i) = static_cast(p_d1s_grid[i]); + + // D1 desc + d1s_grid_desc_m_n_(i) = + DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideD1s[i]); + }); + + if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_, + b0_grid_desc_n_k_, + b1_grid_desc_n_k_, + e1_grid_desc_m_n_, + block_2_e1tile_map_)) + { + e1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e1_grid_desc_m_n_); + + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = + GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( + d0s_grid_desc_m_n_); + + d1s_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d1s_grid_desc_m_n_); + } + } + + // private: + // pointers + const A0DataType* p_a0_grid_; + const B0DataType* p_b0_grid_; + typename GridwiseGemm::D0sGridPointer p_d0s_grid_; + const B1DataType* p_b1_grid_; + typename GridwiseGemm::D1sGridPointer p_d1s_grid_; + E1DataType* p_e1_grid_; + + // tensor descriptors for problem definiton + A0GridDesc_M_K a0_grid_desc_m_k_; + B0GridDesc_N_K b0_grid_desc_n_k_; + D0sGridDesc_M_N d0s_grid_desc_m_n_; + B1GridDesc_N_K b1_grid_desc_n_k_; + D1sGridDesc_M_N d1s_grid_desc_m_n_; + E1GridDesc_M_N e1_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_; + B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + d1s_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e1_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e1-tile map + typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_; + + // element-wise op + A0ElementwiseOperation a0_element_op_; + B0ElementwiseOperation b0_element_op_; + CDE0ElementwiseOperation cde0_element_op_; + B1ElementwiseOperation b1_element_op_; + CDE1ElementwiseOperation cde1_element_op_; + + // batch + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = arg.a0_grid_desc_m_k_.GetLength(I1); + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_gemm_xdl_cshuffle_v1< + GridwiseGemm, + A0DataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::D0sGridPointer, + typename GridwiseGemm::D1sGridPointer, + E1DataType, + A0ElementwiseOperation, + B0ElementwiseOperation, + CDE0ElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + DeviceOp::A0GridDesc_AK0_M_AK1, + DeviceOp::B0GridDesc_BK0_N_BK1, + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2E1TileMap, + ComputeBasePtrOfStridedBatch, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a0_grid_, + arg.p_b0_grid_, + arg.p_d0s_grid_, + arg.p_b1_grid_, + arg.p_d1s_grid_, + arg.p_e1_grid_, + arg.a0_element_op_, + arg.b0_element_op_, + arg.cde0_element_op_, + arg.b1_element_op_, + arg.cde1_element_op_, + arg.a0_grid_desc_ak0_m_ak1_, + arg.b0_grid_desc_bk0_n_bk1_, + arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_e1tile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const A0DataType* p_a0, + const B0DataType* p_b0, + std::array p_d0s, + const B1DataType* p_b1, + std::array p_d1s, + E1DataType* p_e1, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA0, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) + { + return Argument{p_a0, p_b0, + p_d0s, p_b1, + p_d1s, p_e1, + MRaw, NRaw, + KRaw, Gemm1NRaw, + Batch, StrideA0, + StrideB0, StrideD0s, + StrideB1, StrideD1s, + StrideE1, BatchStrideA0, + BatchStrideB0, BatchStrideD0s, + BatchStrideB1, BatchStrideD1s, + BatchStrideE1, a0_element_op, + b0_element_op, cde0_element_op, + b1_element_op, cde1_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a0, + const void* p_b0, + std::array p_d0s, + const void* p_b1, + std::array p_d1s, + void* p_e1, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA0, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) override + { + return std::make_unique(static_cast(p_a0), + static_cast(p_b0), + p_d0s, + static_cast(p_b1), + p_d1s, + static_cast(p_e1), + MRaw, + NRaw, + KRaw, + Gemm1NRaw, + Batch, + StrideA0, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideE1, + BatchStrideA0, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << Gemm0MPerBlock << ", " + << Gemm0NPerBlock << ", " + << Gemm0KPerBlock << ", " + << A0K1 << ", " + << B0K1 << ", " + << B1K1 << ", " + << Gemm0MPerXdl << ", " + << Gemm0NPerXdl << ", " + << Gemm0MXdlPerWave << ", " + << Gemm0NXdlPerWave << ", " + << Gemm1NXdlPerWave << "> "; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3c5fdbdab0f25e9b8826ce4acb0f700ef1f63d25 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -0,0 +1,1001 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); + p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; + }); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_reduces_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = compute_base_ptr_of_batch_; + ignore = block_2_ctile_map; +#endif +} + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideD) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideD_(BatchStrideD) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + template + __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx, + Number reduction_idx) const + { + // TODO - Support sequence of StrideD in MakeArgument() + (void)reduction_idx; + return g_idx * static_cast(BatchStrideD_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideD_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops, + index_t Batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + Batch_(Batch), + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + compute_base_ptr_of_batch_{ + type_convert(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), + type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), + type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), + type_convert(reduce_grid_desc_m_.GetElementSpaceSize())}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t Batch_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; + + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}" + << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + auto casted_p_arg = dynamic_cast(p_arg); + if(casted_p_arg == nullptr) + { + return false; + } + else + { + return IsSupportedArgument(*casted_p_arg); + } + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + Batch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + Batch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5baa0f8d9adfb8c563cb0f2fe80c3910799ff736 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -0,0 +1,859 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, + const C0MatrixMask c0_matrix_mask) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map, + c0_matrix_mask); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; + ignore = c0_matrix_mask; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + +#if 0 + // TODO ANT: use alias + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimN; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimO; + static constexpr index_t NumDimGemm1K = NumDimN; +#endif + + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< + Sequence, + Sequence, + GemmSpec, + ASpec, + BSpec, + B1Spec, + CSpec>; + + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), + Number{}); + } + + static auto + MakeB1GridDescriptor_BK0_N_BK1(const std::vector& b1_gs_gemm1ns_gemm1ks_lengths_vec, + const std::vector& b1_gs_gemm1ns_gemm1ks_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, + b1_gs_gemm1ns_gemm1ks_strides_vec), + Number{}); + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const BGridDesc_G_N_K& b_grid_desc_g_n_k, + const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b_grid_desc_g_n_k_(b_grid_desc_g_n_k), + b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; + + // Argument + // FIXME: constness + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( + b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, + c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_g_n_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( + b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, + c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, + raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], + b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], + b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1], + b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, + b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1], + b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, + c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_base_ptr_of_batch_{ + a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ns_lengths; + ignore = acc0_biases_gs_ms_ns_strides; + ignore = acc1_biases_gs_ms_gemm1ns_lengths; + ignore = acc1_biases_gs_ms_gemm1ns_strides; + + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + b1_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", " + << a_grid_desc_g_m_k_.GetLength(I1) << ", " + << a_grid_desc_g_m_k_.GetLength(I2) << '\n'; + // a_grid_desc_g_m_k_.Print(); + std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", " + << b_grid_desc_g_n_k_.GetLength(I1) << ", " + << b_grid_desc_g_n_k_.GetLength(I2) << '\n'; + // b_grid_desc_g_n_k_.Print(); + std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " + << b1_grid_desc_g_n_k_.GetLength(I1) << ", " + << b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; + // b1_grid_desc_g_n_k_.Print(); + std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", " + << c_grid_desc_g_m_n_.GetLength(I1) << ", " + << c_grid_desc_g_m_n_.GetLength(I2) << '\n'; + // c_grid_desc_g_m_n_.Print(); + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // tensor descriptor + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-c-tile map + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // For robust IsSupportedArgument() check + std::vector raw_lengths_mz_nz_kz_gemm1nz_; + std::vector a_mz_kz_strides_; + std::vector b_nz_kz_strides_; + std::vector b1_nz_kz_strides_; + std::vector c_mz_gemm1nz_strides_; + + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error("wrong! unsupported argument"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + ComputeBasePtrOfStridedBatch, + C0MatrixMask, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_, + arg.c0_matrix_mask_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { +#if 0 + arg.Print(); +#endif + + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + // TODO ANT: Check if tensor specialization & strides mismatch + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0); + const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1); + const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); + + if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) + { + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; + const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2]; + const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; + const auto c_extent_lowest = Gemm1NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b_stride_lowest = + BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0]; + const auto c_stride_lowest = + arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous + + if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides, + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + // FIXME: constness + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, // cast in struct Argument + p_acc1_biases, // cast in struct Argument + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides, + acc1_biases_gs_ms_gemm1ns_lengths, + acc1_biases_gs_ms_gemm1ns_strides, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(BSpec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f21f2d7122ca7b31553c915f6698f249c3f0aaf --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -0,0 +1,771 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, + const C0MatrixMask c0_matrix_mask) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map, + c0_matrix_mask); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; + ignore = c0_matrix_mask; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) + +// When using NPadding as GemmSpecialization, AccElementwiseOperation should be set to +// ScaleAndResetNaNToMinusInfinity. +// if !isNan(AccElement) +// AccElement *= scale +// else +// AccElement = -INFINITY +// Otherwise, result may be wrong. + +template +struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle + : public DeviceBatchedGemmSoftmaxGemm +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 + static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideB1_; + index_t BatchStrideC_; + }; + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + using C0MatrixMask = conditional_t, + C0MatrixMask_impl>; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN, + MaskOutUpperTriangle>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, // = ORaw + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + b1_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + batch_count_(Batch), + compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}, + c0_matrix_mask_{NRaw}, + raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + b1_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // For robust IsSupportedArgument() check + std::vector raw_lengths_m_n_k_o_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + ComputeBasePtrOfStridedBatch, + C0MatrixMask, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_, + arg.c0_matrix_mask_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; + const auto NRaw = arg.raw_lengths_m_n_k_o_[1]; + const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; + const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = + is_same_v ? KRaw : MRaw; + const auto b_extent_lowest = + is_same_v ? NRaw : KRaw; + const auto b1_extent_lowest = + is_same_v ? Gemm1NRaw : NRaw; + const auto c_extent_lowest = + is_same_v ? Gemm1NRaw : MRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_b1, p_c, MRaw, + NRaw, KRaw, Gemm1NRaw, Batch, StrideA, + StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB, + BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op, + b1_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_b1), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + Gemm1NRaw, + Batch, + StrideA, + StrideB, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideB1, + BatchStrideC, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5ea3296356fb1ea5cb46a52a214afc02c9b07641 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -0,0 +1,668 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_v2r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +struct DeviceBatchedGemmXdl : public DeviceBatchedGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + const auto a_grid_desc_k0_mp_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_mp_k1; + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + const auto b_grid_desc_k0_np_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_np_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + const auto c_grid_desc_mp_np = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return c_grid_desc_mp_np; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + }; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + Batch_(Batch), + a_grid_desc_k0_m_k1_{ + DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, + b_grid_desc_k0_n_k1_{ + DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, + c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC}, + block_2_ctile_map_{ + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + kraw_{K} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + index_t Batch_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t kraw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_batched_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.Batch_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.Batch_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.kraw_ % K1 != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">" + << " NumGemmKPrefetchStage: " + << NumGemmKPrefetchStage << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab16a757f699bb4e4ba781873bbf87538633d9c1 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp @@ -0,0 +1,874 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0 && + MThreadSliceSize % DxDstVectorSize == 0) || + (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0 && + KThreadSliceSize % DxDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + using MeanVarGridDesc_M = ScaleBiasGridDesc_M; + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array dyStrides, + const std::array dxStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const DyDataType* p_dy, + const ScaleDataType* p_scale, + const MeanVarDataType* p_savedMean, + const MeanVarDataType* p_savedInvVar, + const DyElementwiseOp dy_elementwise_op, + double epsilon, + DxDataType* p_dx, + DscaleDbiasDataType* p_dscale, + DscaleDbiasDataType* p_dbias) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnDscaleDbiasStrides_(bnDscaleDbiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_dy_(p_dy), + p_scale_(p_scale), + p_savedMean_(p_savedMean), + p_savedInvVar_(p_savedInvVar), + dy_elementwise_op_(dy_elementwise_op), + p_dx_(p_dx), + p_dscale_(p_dscale), + p_dbias_(p_dbias) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + dyStrides_ = + shuffle_tensor_dimensions(dyStrides, reduceDims); + dxStrides_ = + shuffle_tensor_dimensions(dxStrides, reduceDims); + + std::tie(invariant_length, reduce_length) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + + haveSavedMeanInvVar_ = (p_savedMean_ != nullptr && p_savedInvVar_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = (reduce_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize = (invariant_length + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize; + + x_grid_desc_m_k = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize, numBlockTileIteration); + dy_grid_desc_m_k = + MakeXY2dDescriptor(xyLengths_, dyStrides_, blkGroupSize, numBlockTileIteration); + dx_grid_desc_m_k = + MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration); + scale_grid_desc_m = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides); + dscale_dbias_grid_desc_m = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides); + mean_var_grid_desc_m = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides); + } + + AccDataType epsilon_; + + bool haveSavedMeanInvVar_; + + std::array xyLengths_; + std::array xStrides_; + std::array dyStrides_; + std::array dxStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnDscaleDbiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const DyDataType* p_dy_; + const ScaleDataType* p_scale_; + const MeanVarDataType* p_savedMean_; + const MeanVarDataType* p_savedInvVar_; + const DyElementwiseOp dy_elementwise_op_; + DxDataType* p_dx_; + DscaleDbiasDataType* p_dscale_; + DscaleDbiasDataType* p_dbias_; + + long_index_t invariant_length; + long_index_t reduce_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + XYGridDesc_M_K x_grid_desc_m_k; + XYGridDesc_M_K dy_grid_desc_m_k; + XYGridDesc_M_K dx_grid_desc_m_k; + ScaleBiasGridDesc_M scale_grid_desc_m; + ScaleBiasGridDesc_M dscale_dbias_grid_desc_m; + MeanVarGridDesc_M mean_var_grid_desc_m; + + void* workspace_mean; + void* workspace_variance; + void* workspace_count; + + void* workspace_savedMean; + void* workspace_savedInvVar; + + void* workspace_reduce_dscale; + void* workspace_reduce_dbias; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize > 1) + { + // workspace for the partial reduced result for dscale + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64; + + // workspace for the partial reduced result for dbias + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64; + + if(!pArg_->haveSavedMeanInvVar_) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t) + 64; + + // workspace for welford result mean + workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64; + + // workspace for welford result inv_variance + workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64; + }; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + index_t space_sz; + + // setup buffer for the partial reduced result for dscale + pArg_->workspace_reduce_dscale = pArg_->p_workspace_; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for the partial reduced result for dbias + pArg_->workspace_reduce_dbias = + reinterpret_cast(pArg_->workspace_reduce_dscale) + space_sz; + + if(UseMultiblockInK && pArg_->blkGroupSize > 1) + { + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford intermediate mean + pArg_->workspace_mean = + reinterpret_cast(pArg_->workspace_reduce_dbias) + space_sz; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford intermediate varirance + pArg_->workspace_variance = reinterpret_cast(pArg_->workspace_mean) + space_sz; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford intermediate count + pArg_->workspace_count = reinterpret_cast(pArg_->workspace_variance) + space_sz; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford result mean + pArg_->workspace_savedMean = reinterpret_cast(pArg_->workspace_count) + space_sz; + + space_sz = pArg_->invariant_length * sizeof(MeanVarDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford result inv_variance + pArg_->workspace_savedInvVar = + reinterpret_cast(pArg_->workspace_savedMean) + space_sz; + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + const auto dscale_dbias_grid_desc_m_g = + DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + const auto dscale_dbias_grid_desc_m_k = + DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + using DscaleDbiasGridDesc_M_G = decltype(dscale_dbias_grid_desc_m_g); + using DscaleDbiasGridDesc_M_K = decltype(dscale_dbias_grid_desc_m_k); + + using GridwiseWelfordSecondHalfReduceFirstHalf_ = + GridwiseWelfordSecondHalfReduceFirstHalf; + + using GridwiseReduceSecondHalfBatchNormBwdFinal_ = + GridwiseReduceSecondHalfBatchNormBackwardFinal; + + if(UseMultiblockInK && arg.blkGroupSize > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize, arg.numBlockTileIteration, arg.reduce_length); + + if(!arg.haveSavedMeanInvVar_) + { + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration, + arg.p_x_, + static_cast(arg.workspace_mean), + static_cast(arg.workspace_variance), + static_cast(arg.workspace_count)); + }; + + const auto kern_welford_second_half_reduce_first_half = + kernel_welford_second_half_reduce_first_half< + GridwiseWelfordSecondHalfReduceFirstHalf_, + XDataType, + DyDataType, + AccDataType, + ScaleDataType, + DscaleDbiasDataType, + MeanVarDataType, + DyElementwiseOp, + XYGridDesc_M_K, + MeanVarGridDesc_M, + MeanVarCountGridDesc_M_K, + DscaleDbiasGridDesc_M_G>; + + const auto kern_reduce_second_half_batchnorm_backward_final = + kernel_reduce_second_half_batchnorm_backward_final< + GridwiseReduceSecondHalfBatchNormBwdFinal_, + XDataType, + DyDataType, + DxDataType, + ScaleDataType, + DscaleDbiasDataType, + MeanVarDataType, + DyElementwiseOp, + XYGridDesc_M_K, + DscaleDbiasGridDesc_M_K, + MeanVarGridDesc_M, + ScaleBiasGridDesc_M>; + + index_t numDscaleDbiasBlockTileIteration = + (arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize; + + avg_time += launch_and_time_kernel( + stream_config, + kern_welford_second_half_reduce_first_half, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + arg.dy_grid_desc_m_k, + arg.mean_var_grid_desc_m, + mean_var_count_grid_desc_m_k, + dscale_dbias_grid_desc_m_g, + arg.blkGroupSize, + arg.numBlockTileIteration, + numDscaleDbiasBlockTileIteration, + arg.epsilon_, + arg.haveSavedMeanInvVar_, + arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr, + arg.haveSavedMeanInvVar_ ? arg.p_savedInvVar_ : nullptr, + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_mean), + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_variance), + arg.haveSavedMeanInvVar_ ? nullptr + : static_cast(arg.workspace_count), + arg.dy_elementwise_op_, + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_savedMean), + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_savedInvVar), + arg.p_x_, + arg.p_dy_, + static_cast(arg.workspace_reduce_dscale), + static_cast(arg.workspace_reduce_dbias)); + + avg_time += launch_and_time_kernel( + stream_config, + kern_reduce_second_half_batchnorm_backward_final, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + arg.dy_grid_desc_m_k, + arg.dx_grid_desc_m_k, + dscale_dbias_grid_desc_m_k, + arg.mean_var_grid_desc_m, + arg.scale_grid_desc_m, + arg.dscale_dbias_grid_desc_m, + arg.blkGroupSize, + arg.reduce_length, + arg.numBlockTileIteration, + numDscaleDbiasBlockTileIteration, + static_cast(arg.workspace_reduce_dscale), + static_cast(arg.workspace_reduce_dbias), + arg.haveSavedMeanInvVar_ + ? arg.p_savedMean_ + : static_cast(arg.workspace_savedMean), + arg.haveSavedMeanInvVar_ + ? arg.p_savedInvVar_ + : static_cast(arg.workspace_savedInvVar), + arg.p_x_, + arg.p_dy_, + arg.p_scale_, + arg.dy_elementwise_op_, + arg.p_dx_, + arg.p_dscale_, + arg.p_dbias_); + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration, arg.reduce_length); + + using GridwiseBatchNormBackwardWithBlockwiseWelford_ = + GridwiseBatchNormBackwardWithBlockwiseWelford; + + const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford< + GridwiseBatchNormBackwardWithBlockwiseWelford_, + XDataType, + DyDataType, + DxDataType, + AccDataType, + ScaleDataType, + DscaleDbiasDataType, + MeanVarDataType, + DyElementwiseOp, + XYGridDesc_M_K, + ScaleBiasGridDesc_M, + MeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_bwd, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + arg.dy_grid_desc_m_k, + arg.dx_grid_desc_m_k, + arg.scale_grid_desc_m, + arg.dscale_dbias_grid_desc_m, + arg.mean_var_grid_desc_m, + get_reduce_count_per_thread, + arg.reduce_length, + arg.numBlockTileIteration, + arg.epsilon_, + arg.p_x_, + arg.p_dy_, + arg.p_scale_, + arg.haveSavedMeanInvVar_, + arg.p_savedMean_, + arg.p_savedInvVar_, + arg.dy_elementwise_op_, + arg.p_dx_, + arg.p_dscale_, + arg.p_dbias_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XDyDxVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->dyStrides_[NumInvariantDim - 1] != 1 || + pArg_->dxStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % DySrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % DxDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->dyStrides_[Rank - 1] != 1 || + pArg_->dxStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % DySrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % DxDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return false; + + if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0) + return false; + + if(pArg_->haveSavedMeanInvVar_) + { + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcVectorSize != 0) + return false; + }; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array dyStrides, + const std::array dxStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_dy, + const void* p_scale, + const void* p_savedMean, + const void* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + void* p_dx, + void* p_dscale, + void* p_dbias) override + { + return std::make_unique(xyLengths, + xStrides, + dyStrides, + dxStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnDscaleDbiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_dy), + static_cast(p_scale), + static_cast(p_savedMean), + static_cast(p_savedInvVar), + dy_elementwise_op, + epsilon, + static_cast(p_dx), + static_cast(p_dscale), + static_cast(p_dbias)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormBwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XDyDxVectorDim_" << XDyDxVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a16ff765bbc5c0bc1b68d80262c88a0c8f84416 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -0,0 +1,718 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* p_scale, + const BiasDataType* p_bias, + const YElementwiseOp y_elementwise_op, + double epsilon, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_scale_(p_scale), + p_bias_(p_bias), + y_elementwise_op_(y_elementwise_op), + p_y_(p_y), + resultSaveMean_(resultSaveMean), + resultSaveInvVariance_(resultSaveInvVariance), + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = + shuffle_tensor_dimensions(yStrides, reduceDims); + + std::tie(invariant_length_, reduce_length_) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + + updateMovingAverage_ = + (resultRunningMean != nullptr && resultRunningVariance != nullptr); + saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration_ = iterations; + } + else + { + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_; + + x_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + scale_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_); + bias_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_); + mean_var_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_); + } + + AccDataType epsilon_; + AccDataType averageFactor_; + + bool updateMovingAverage_; + bool saveMeanInvVariance_; + + std::array xyLengths_; + std::array xStrides_; + std::array yStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnBiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const ScaleDataType* p_scale_; + const BiasDataType* p_bias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; + + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; + + long_index_t invariant_length_; + long_index_t reduce_length_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + XYGridDesc_M_K x_grid_desc_m_k_; + XYGridDesc_M_K y_grid_desc_m_k_; + ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_; + + void* workspace_mean_; + void* workspace_variance_; + void* workspace_count_; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + + // setup buffer used for intermediate welford mean + pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->workspace_variance_ = + reinterpret_cast(pArg_->workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welfor count + pArg_->workspace_count_ = + reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(UseMultiblockInK && arg.blkGroupSize_ > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_); + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + using GridwiseWelfordSecondHalfBatchNormForwardFinal_ = + GridwiseWelfordSecondHalfBatchNormForwardFinal; + + index_t numMeanVarCountBlockTileIteration = + (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize; + + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += + launch_and_time_kernel(stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += + launch_and_time_kernel(stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + numMeanVarCountBlockTileIteration, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration_, arg.reduce_length_); + + using GridwiseBatchNormForwardWithBlockwiseWelford_ = + GridwiseBatchNormForwardWithBlockwiseWelford; + + const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford< + GridwiseBatchNormForwardWithBlockwiseWelford_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_fwd, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XSrcYDstVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->yStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return false; + if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0) + return false; + + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0) + return false; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_scale, + const void* p_bias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_scale), + static_cast(p_bias), + y_elementwise_op, + epsilon, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormFwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..29978458bb202297868af555b3ffecb617e48fb7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -0,0 +1,948 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_cgemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ALayout, + typename BLayout, + typename CLayout, + typename ADataType, + typename BDataType, + typename CDataType, + typename GemmAccDataType, + typename CShuffleDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t NumGemmKPrefetchStage, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1, + index_t BK1, + index_t MPerXDL, + index_t NPerXDL, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceCGemm_4Gemm_Xdl_CShuffle + : public DeviceCGemm +{ + using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto MPerThread = Number<4>{}; + static constexpr auto AScalarPerVector = Number<4>{}; + static constexpr auto BScalarPerVector = Number<4>{}; + static constexpr auto CScalarPerVector = Number<4>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + const auto M = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(M, loop_step) - M; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(M, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& strides, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{}); + auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid_real, + const ADataType* p_a_grid_imag, + const BDataType* p_b_grid_real, + const BDataType* p_b_grid_imag, + CDataType* p_c_grid_real, + CDataType* p_c_grid_imag, + CDataType* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_real_{p_a_grid_real}, + p_a_grid_imag_{p_a_grid_imag}, + p_b_grid_real_{p_b_grid_real}, + p_b_grid_imag_{p_b_grid_imag}, + p_c_grid_real_{p_c_grid_real}, + p_c_grid_imag_{p_c_grid_imag}, + p_aux_grid_{p_workspace}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + + const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); + + if constexpr(is_same::value) + { + c_grid_desc_m_ = + DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); + } + else if constexpr(is_same::value) + { + c_grid_desc_m_ = + DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); + } + + p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize(); + } + + // private: + const ADataType* p_a_grid_real_; + const ADataType* p_a_grid_imag_; + const BDataType* p_b_grid_real_; + const BDataType* p_b_grid_imag_; + CDataType* p_c_grid_real_; + CDataType* p_c_grid_imag_; + CDataType* p_aux_grid_; + CDataType* p_aux_2_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M c_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + using Add = ck::tensor_operation::element_wise::Add; + using Subtract = ck::tensor_operation::element_wise::Subtract; + + using GridwiseBinAdd = + GridwiseElementwise_1D, + Tuple, + Tuple, + Tuple, + Add, + MPerThread, + Sequence, + Sequence>; + + using GridwiseBinSubtract = + GridwiseElementwise_1D, + Tuple, + Tuple, + Tuple, + Subtract, + MPerThread, + Sequence, + Sequence>; + + const auto add_kernel = kernel_elementwise_1d, + Tuple, + Tuple, + Tuple, + Add>; + + const auto subtract_kernel = + kernel_elementwise_1d, + Tuple, + Tuple, + Tuple, + Subtract>; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_real_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_imag_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel( + stream_config, + subtract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), + make_tuple(arg.c_grid_desc_m_), + make_tuple(const_cast(arg.p_aux_grid_), + const_cast(arg.p_aux_2_grid_)), + make_tuple(arg.p_c_grid_real_), + Subtract{}); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_imag_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_real_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel( + stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), + make_tuple(arg.c_grid_desc_m_), + make_tuple(const_cast(arg.p_aux_grid_), + const_cast(arg.p_aux_2_grid_)), + make_tuple(arg.p_c_grid_imag_), + Add{}); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_real_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_imag_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel( + stream_config, + subtract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), + make_tuple(arg.c_grid_desc_m_), + make_tuple(const_cast(arg.p_aux_grid_), + const_cast(arg.p_aux_2_grid_)), + make_tuple(arg.p_c_grid_real_), + Subtract{}); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_imag_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_real_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel( + stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), + make_tuple(arg.c_grid_desc_m_), + make_tuple(const_cast(arg.p_aux_grid_), + const_cast(arg.p_aux_2_grid_)), + make_tuple(arg.p_c_grid_imag_), + Add{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a_real, + const ADataType* p_a_imag, + const BDataType* p_b_real, + const BDataType* p_b_imag, + CDataType* p_c_real, + CDataType* p_c_imag, + CDataType* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a_real, + p_a_imag, + p_b_real, + p_b_imag, + p_c_real, + p_c_imag, + p_workspace, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a_real, + const void* p_a_imag, + const void* p_b_real, + const void* p_b_imag, + void* p_c_real, + void* p_c_imag, + void* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a_real), + static_cast(p_a_imag), + static_cast(p_b_real), + static_cast(p_b_imag), + static_cast(p_c_real), + static_cast(p_c_imag), + static_cast(p_workspace), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceCGemm_4Gemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } + + std::size_t GetWorkspaceSize(index_t MRaw, + index_t NRaw, + [[maybe_unused]] index_t KRaw, + [[maybe_unused]] index_t StrideA, + [[maybe_unused]] index_t StrideB, + index_t StrideC) override + { + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); + + return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..72c6d0b6f74aab5ec8bd4d0c0d0106fc0ca7bd19 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,779 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceContractionMultipleD_Xdl_CShuffle + : public DeviceContractionMultipleD +{ + using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_vec, + const std::vector& a_ms_ks_strides_vec) + { + assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK && + a_ms_ks_strides_vec.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_vec, + const std::vector& b_ns_ks_strides_vec) + { + assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK && + b_ns_ks_strides_vec.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_vec, + const std::vector& e_ms_ns_strides_vec) + { + assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN && + e_ms_ns_strides_vec.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i], + ds_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{} + { + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; + a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; + + b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; + b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!(arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_ms_ns_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_ms_ns_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4760422bf44472cd128f043baf9ce3d7df8f3a00 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,787 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdWeight<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + static constexpr ck::index_t NDimSpatial = 2; + + using DeviceOp = + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static constexpr auto N1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const index_t N0 = N / N1Number; + const index_t GemmK0Total = N0 * Ho * Wo; + + const index_t GemmK0S = + math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock; + const index_t GemmK0Pad = GemmKBatch * GemmK0S; + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K)); + + const auto out_n0_ho_wo_k_n1_grid_desc = + transform_tensor_descriptor(out_n_ho_wo_k_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0total_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)), + make_pass_through_transform(K), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0total_gemmm_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0pad_gemmm_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n0_y_ho_x_wo_c_n1_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Y), + make_pass_through_transform(Ho), + make_pass_through_transform(X), + make_pass_through_transform(Wo), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 6>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_n0_y_ho_x_wo_c_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C)), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0total_gemmn_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0pad_gemmn_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::array output_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // unmerge N to N0 and N1, where N1 equals to K1 + if(!(arg.Conv_N_ % K1 == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ca79b932b6d6dd180b1b25babcf58191be35d96a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,835 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdData<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t i_ytilde, + index_t i_xtilde) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + i_ytilde, + i_xtilde); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + std::vector + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; + std::vector block_2_ctile_map_container_; + index_t M01_; + index_t N01_; + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) + << " ) " << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); + + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + true>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + false>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4749665c4f6faddf1d6f29721c0f1647b0c8952d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,968 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivationAdd +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + } + + using GridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; + + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + p_c1_grid_{p_resi_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c1_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + c1_grid_desc_m_n_ = descs[I4]; + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c1_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + const CDataType* p_c1_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + p_resi_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + const void* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + static_cast(p_resi_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bafbfe4d70e0455eb8f5d1ca74b0524420075b66 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,925 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + InMemoryDataOperationEnum OutGlobalMemoryDataOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivation +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + OutGlobalMemoryDataOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a6d24bf6c51de84a9720d8b3b4ec81b73ef6517 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,893 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXdl, + ck::index_t NPerXdl, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, // TODO: Add ShuffleType for DeviceConv2d + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock * K1, + K1, // AK1 + K1, // BK1 + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + } + + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout + << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_" + "nwavenperxdl_{ " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I0) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I1) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I2) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I3) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I4) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I5) + << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5821e06b2c9c31cbb93313236369179927fa2ab7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,733 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + const index_t GemmK = Y * X * C; + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f69d8f18ae036b8e7cd060a80a0cca12bfc2a050 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef DEVICE_CONV3D_FWD_NAIVE_HPP +#define DEVICE_CONV3D_FWD_NAIVE_HPP + +#include +#include +#include +#include "conv_util.hpp" +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "naive_conv_fwd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp = DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : params_{3, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + out_spatial_lengths_{output_spatial_lengths}, + p_in_{p_in}, + p_wei_{p_wei}, + p_out_{p_out}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + + { + } + + // private: + utils::conv::ConvParams params_; + std::vector out_spatial_lengths_; + + const InDataType* p_in_; + const WeiDataType* p_wei_; + OutDataType* p_out_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto naive_conv3d_fwd = + ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk; + + float ave_time = launch_and_time_kernel(stream_config, + naive_conv3d_fwd, + dim3(256), + dim3(256), + 0, + arg.p_in_, + arg.p_wei_, + arg.p_out_, + arg.N_, + arg.K_, + arg.C_, + arg.in_spatial_lengths_[0], + arg.in_spatial_lengths_[1], + arg.in_spatial_lengths_[2], + arg.filter_spatial_lengths_[0], + arg.filter_spatial_lengths_[1], + arg.filter_spatial_lengths_[2], + arg.out_spatial_lengths_[0], + arg.out_spatial_lengths_[1], + arg.out_spatial_lengths_[2], + arg.conv_filter_strides_[0], + arg.conv_filter_strides_[1], + arg.conv_filter_strides_[2], + arg.conv_filter_dilations_[0], + arg.conv_filter_dilations_[1], + arg.conv_filter_dilations_[2], + arg.in_left_pads_[0], + arg.in_left_pads_[1], + arg.in_left_pads_[2]); + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + std::vector out_spatial_lengths = arg.params_.GetOutputSpatialLengths(); + + bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] && + out_spatial_lengths[1] == arg.out_spatial_lengths_[1] && + out_spatial_lengths[2] == arg.out_spatial_lengths_[2]; + return out_lengths_are_consistent; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<>"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f950538d01f8abdf5b43b84119d7f0020c9aa1f3 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,642 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef DEVICE_CONV3D_FWD_XDL_HPP +#define DEVICE_CONV3D_FWD_XDL_HPP + +#include +#include +#include +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "convolution_forward_specialization.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \see \link impl/device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink. + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3_for_conv3d( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t num_batches, + const index_t a_batch_stride, + const index_t b_batch_stride, + const index_t c_batch_stride, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(a_batch_stride) * g_idx); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(b_batch_stride) * g_idx); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(c_batch_stride) * g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = num_batches; + ignore = a_batch_stride; + ignore = b_batch_stride; + ignore = c_batch_stride; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp = DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + /* + * \brief Split the number of batches, \p N, into N = B * N1, such that the memory + * space of input and output tensors stays with the value range of index_t, and each subbatch + * can be dealed with GridwiseGemm. + */ + static index_t GetMaxAllowableSubBatchSize(const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector output_spatial_lengths) + { + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + // N1 should satisfy that + // 1) N % N1 = 0; + // 2) N1 * (Do * Ho * Wo * K) < (2^31 - 1) + // 3) N1 * (Di * Hi * Wi * C) < (2^31 - 1) + // + // Do NOT confuse (B, N1) in this function with (B, N1) in gridewise GEMM. + auto N1 = N + 1; + + const auto stride = + math::max(long_index_t(Do) * Ho * Wo * K, long_index_t(Di) * Hi * Wi * C); + const index_t max_stride = NumericLimits::Max(); + + for(index_t n0 = 1; n0 <= N; ++n0) + { + index_t n1 = N / n0; + if(n0 * n1 == N && long_index_t(n1) * long_index_t(stride) < max_stride) + { + N1 = n1; + break; + } + } + + const auto B = N / N1; + if(B * N1 != N) + { + throw std::runtime_error(__func__ + + std::string(": failed to find num_subbatches for conv3d.\n")); + } + + return N1; + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + assert(input_spatial_lengths.size() > 2); + assert(filter_spatial_lengths.size() > 2); + assert(conv_filter_strides.size() > 2); + assert(conv_filter_dilations.size() > 2); + assert(input_left_pads.size() > 2); + assert(input_right_pads.size() > 2); + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default, + "Wrong! This specialization not implemented!"); + + const auto in_desc_n_di_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto wei_desc_k_z_y_x_c = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto out_desc_n_do_ho_wo_k = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + + const auto descs = transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + in_desc_n_di_hi_wi_c, + wei_desc_k_z_y_x_c, + out_desc_n_do_ho_wo_k, + make_tuple(conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]), + make_tuple( + conv_filter_dilations[0], conv_filter_dilations[1], conv_filter_dilations[2]), + make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]), + make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]), + Number{}); + + return descs; + } + + using ABCGridDescs = remove_cvref_t; + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + InDataType, + AccDataType, + OutDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t M01, + index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in}, + p_b_grid_{p_wei}, + p_c_grid_{p_out}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + const index_t subbatch_size = + GetMaxAllowableSubBatchSize(N, K, C, input_spatial_lengths, output_spatial_lengths); + num_subbatches_ = N / subbatch_size; + + const auto descs = + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(subbatch_size, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + b_batch_stride_ = 0; + c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize(); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const InDataType* p_a_grid_; + const WeiDataType* p_b_grid_; + OutDataType* p_c_grid_; + index_t num_subbatches_; + index_t a_batch_stride_; + index_t b_batch_stride_; + index_t c_batch_stride_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; + std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * + arg.num_subbatches_; + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1dc39d9caac5ba4a5a4cea4c25cc92b84ce6998b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -0,0 +1,1583 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvNdBwdDataNwcKxcNwk_Dl + : public DeviceConvBwdData< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Dl; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_xtilde = tildes[0]; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + + const auto K0 = K / K1; + + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)), + make_tuple(make_pass_through_transform(N * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto out_n_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K)); + const auto wei_k_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X, C)); + + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_n_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + const index_t i_ztilde = tildes[0]; + const index_t i_ytilde = tildes[1]; + const index_t i_xtilde = tildes[2]; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const auto K0 = K / K1; + + const auto out_n_do_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + const auto wei_k_z_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7, 8>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_k_z_y_x_c_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<5>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + CreateABCDesc(); + } + + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideW = conv_filter_strides_[0]; + const index_t ConvDilationW = conv_filter_dilations_[0]; + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t X = filter_spatial_lengths_[0]; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2])) + { + a_grid_desc_k0_m0_m1_k1_container_.push_back( + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0])); + b_grid_desc_k0_n0_n1_k1_container_.push_back( + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1])); + c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2])); + + block_2_ctile_map_container_.push_back( + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2])); + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2])) + { + a_grid_desc_k0_m0_m1_k1_container_.push_back( + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0])); + b_grid_desc_k0_n0_n1_k1_container_.push_back( + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1])); + c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2])); + + block_2_ctile_map_container_.push_back( + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2])); + } + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideD = conv_filter_strides_[0]; + const index_t ConvStrideH = conv_filter_strides_[1]; + const index_t ConvStrideW = conv_filter_strides_[2]; + + const index_t ConvDilationD = conv_filter_dilations_[0]; + const index_t ConvDilationH = conv_filter_dilations_[1]; + const index_t ConvDilationW = conv_filter_dilations_[2]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Z = filter_spatial_lengths_[0]; + const index_t Y = filter_spatial_lengths_[1]; + const index_t X = filter_spatial_lengths_[2]; + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ztilde, i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2])) + { + a_grid_desc_k0_m0_m1_k1_container_.push_back( + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0])); + b_grid_desc_k0_n0_n1_k1_container_.push_back( + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1])); + c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2])); + + block_2_ctile_map_container_.push_back( + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2])); + } + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + + std::vector a_grid_desc_k0_m0_m1_k1_container_; + std::vector b_grid_desc_k0_n0_n1_k1_container_; + std::vector c_grid_desc_m0_m10_m11_n0_n10_n11_container_; + + std::vector block_2_ctile_map_container_; + + // element-wise op + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5) + << " ) " << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop; + + const auto kernel = kernel_gemm_dl_v1r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + has_main_loop, + has_double_loop>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_container_[i], + arg.b_grid_desc_k0_n0_n1_k1_container_[i], + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i], + arg.block_2_ctile_map_container_[i]); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_container_[i].GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + launch_kernel(integral_constant{}, integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(!(ck::get_device_name() == "gfx906"||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030")) + { + return false; + } + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // matrix A + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t K = arg.Conv_K_; + + if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + + // matrix B + { + auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{}; + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1) + { + return false; + } + if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 || + srcLoadLenghts[I2] % srcVectorLengths[I2] != 0) + { + return false; + } + + const index_t C = arg.Conv_K_; + + if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0) + { + return false; + } + } + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = " + << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl; + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConvNdBwdDataNwcKxcNwk_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ + + str<< " Filter1x1Stride1Pad0"; + } + + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e10e374b06465ef587f853740971e3bcb5739abf --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -0,0 +1,1568 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvNdBwdDataNwcKxcNwk_Xdl + : public DeviceConvBwdData< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Xdl; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_xtilde = tildes[0]; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + + const auto K0 = K / K1; + + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)), + make_tuple(make_pass_through_transform(N * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto out_n_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K)); + const auto wei_k_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X, C)); + + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_n_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + const index_t i_ztilde = tildes[0]; + const index_t i_ytilde = tildes[1]; + const index_t i_xtilde = tildes[2]; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const auto K0 = K / K1; + + const auto out_n_do_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + const auto wei_k_z_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7, 8>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_k_z_y_x_c_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<5>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + CreateABCDesc(); + } + + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideW = conv_filter_strides_[0]; + const index_t ConvDilationW = conv_filter_dilations_[0]; + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t X = filter_spatial_lengths_[0]; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideD = conv_filter_strides_[0]; + const index_t ConvStrideH = conv_filter_strides_[1]; + const index_t ConvStrideW = conv_filter_strides_[2]; + + const index_t ConvDilationD = conv_filter_dilations_[0]; + const index_t ConvDilationH = conv_filter_dilations_[1]; + const index_t ConvDilationW = conv_filter_dilations_[2]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Z = filter_spatial_lengths_[0]; + const index_t Y = filter_spatial_lengths_[1]; + const index_t X = filter_spatial_lengths_[2]; + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ztilde, i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( + descs[I2])); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + } + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + std::vector + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; + std::vector block_2_ctile_map_container_; + index_t M01_; + index_t N01_; + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7) + << " ) " << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); + + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + true>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + false>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConvNdBwdDataNwcKxcNwk_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ + + str<< " Filter1x1Stride1Pad0"; + } + + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_elementwise.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_elementwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8e6288009860ad9fa950310f15e7b0ffcffc7486 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_elementwise.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwise + : public DeviceElementwiseBase +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::array& lengths, + const std::array& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(NumDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + template + static auto GenerateInOutGrid1dDescTuple(Number) + { + return generate_tuple( + [&](auto) { + if constexpr(NumDim > 1) + { + return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1); + } + else + { + return MakeDescriptor_M({1}, {1}, 1, 1); + }; + }, + Number{}); + }; + + using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + + using GridwiseElementwise = GridwiseElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op), + blockSize_(256), + gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + + in_grid_1d_desc_tuple_ = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + lengths, inStridesArray[I.value], gridSize_, blockSize_); + }, + Number{}); + + out_grid_1d_desc_tuple_ = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + lengths, outStridesArray[I.value], gridSize_, blockSize_); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + InGrid1dDescTuple in_grid_1d_desc_tuple_; + OutGrid1dDescTuple out_grid_1d_desc_tuple_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.in_grid_1d_desc_tuple_, + arg.out_grid_1d_desc_tuple_, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + arg.elementwise_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector) { + if(strides.back() == 1 && lengths.back() % scalarPerVector == 0) + return true; + + if(strides.back() != 1 && scalarPerVector == 1) + return true; + + return false; + }; + + bool valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) + valid = false; + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) + valid = false; + }); + + return valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8ffc5ef9fb48b85c6f685166a4a842da09e95acf --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp @@ -0,0 +1,592 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// X = Elementwise(input1, input2, input3, ...) +// Y = Normalization(X, beta, gamma) +namespace ck { +template // Descriptor of inputs, Gamma, Beta +__global__ void kernel_elementwise_layernorm( + const InGrid2dDescTuple in_grid_2d_desc_tuple, // Descriptor tuple of inputs + const GridDesc_M_K x_grid_desc_m_k, // Descriptor of X + const GridDesc_M_K gamma_grid_desc_m_k, // Descriptor of gamma + const GridDesc_M_K beta_grid_desc_m_k, // Descriptor of beta + const GridDesc_M_K y_grid_desc_m_k, // Descriptor of Y + index_t num_k_block_tile_iteration, // + AccDataType epsilon, // Datatype of epsilon + const InDataTypePointerTuple p_in_global_tuple, // Ptr tuple of input matrixs + const GammaDataType* const __restrict__ p_gamma_global, // Ptr of gamma + const BetaDataType* const __restrict__ p_beta_global, // Ptr of beta + YDataType* const __restrict__ p_y_global, // Ptr of y + const XElementwiseOperation x_elementwise_op, // Operation of input + const YElementwiseOperation y_elementwise_op) // Operation of output of normalization +{ + extern __shared__ XDataType p_x_lds[]; + GridwiseElementwiseReduction::Run(in_grid_2d_desc_tuple, // Descriptor tuple of inputs + x_grid_desc_m_k, // Descriptor of X + gamma_grid_desc_m_k, // Descriptor of Gamma + beta_grid_desc_m_k, // Descriptor of Beta + y_grid_desc_m_k, // Descriptor of Y + num_k_block_tile_iteration, // + epsilon, // epsilon + p_in_global_tuple, // Ptr tuple of inputs + p_x_lds, // Ptr of X + p_gamma_global, // Ptr of gamma + p_beta_global, // Ptr of beta + p_y_global, // Ptr of Y + x_elementwise_op, // Operation of input + y_elementwise_op); // Operation of output of normalization +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = LayerNorm(A + B, Beta, Gamma) +template // Size to write destination Y +struct DeviceElementwiseNormalizationImpl + : public DeviceElementwiseNormalization +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + + using XDataType = YDataType; + + static_assert( + (KThreadSliceSize % GammaSrcVectorSize == 0), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (KThreadSliceSize % BetaSrcVectorSize == 0), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + static constexpr index_t M_BlockTileSize = + MThreadClusterSize * MThreadSliceSize; // num of rows calculated in a block + static constexpr index_t K_BlockTileSize = + KThreadClusterSize * KThreadSliceSize; // num of columns calculated in a block + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t numSrcDim = Rank; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + template + static auto GenerateSrcGrid2dDescTuple(Number) + { + return generate_tuple([&](auto) { return MakeSrc2dDescriptor({1}, {1}, 1, 1); }, + Number{}); + }; + + using InGrid2dDescTuple = decltype(GenerateSrcGrid2dDescTuple(Number{})); + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseReduceLayernormGeneric = + GridwiseElementwiseLayernormWelfordVariance_mk_to_mk; + + using GridwiseReduceLayernormSweepOnce = + GridwiseElementwiseLayernormWelfordVariance_mk_to_mk; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::array, NumInput> inStridesArray, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + XElementwiseOperation x_elementwise_op, + YElementwiseOperation y_elementwise_op, + AccDataType epsilon, + const std::array in_dev_buffers, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y) + : epsilon_(epsilon), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + x_elementwise_op_(x_elementwise_op), + y_elementwise_op_(y_elementwise_op) + { + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + for(int i = 0; i < NumInput; i++) + { + inStridesArray_[i] = + shuffle_tensor_dimensions(inStridesArray[i], reduceDims); + } + + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(Lengths_); + + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize_; + + in_grid_2d_desc_tuple_ = generate_tuple( + [&](auto I) { + return MakeSrc2dDescriptor( + Lengths_, inStridesArray_[I.value], blkGroupSize_, numBlockTileIteration_); + }, + Number{}); + + x_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_); + + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_); + + y_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + + sweep_once_ = + x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + if(!sweep_once_) // if not sweep once, compute memory size for matrix X in lds for + // store Intermediate results + { + int block_TileSize = M_BlockTileSize * reduce_total_length; + x_lds_size_ = block_TileSize * sizeof(XDataType); + } + else + x_lds_size_ = 0; + } + + AccDataType epsilon_; + + InDataTypePointerTuple in_dev_buffers_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + + std::vector Lengths_; + std::array, NumInput> inStridesArray_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + + XElementwiseOperation x_elementwise_op_; + YElementwiseOperation y_elementwise_op_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + InGrid2dDescTuple in_grid_2d_desc_tuple_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K beta_grid_desc_m_k_; + GridDesc_M_K y_grid_desc_m_k_; + bool sweep_once_; + int x_lds_size_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = + arg.sweep_once_ ? kernel_elementwise_layernorm + : kernel_elementwise_layernorm; + + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + arg.x_lds_size_, + arg.in_grid_2d_desc_tuple_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.in_dev_buffers_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.x_elementwise_op_, + arg.y_elementwise_op_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + if constexpr(XYSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + for(int i = 0; i < NumInput; i++) + { + if(p_arg_->inStridesArray_[i][NumInvariantDim - 1] != 1) + return false; + } + + if(p_arg_->inStridesArray_[0][NumInvariantDim - 1] != 1 && + p_arg_->inStridesArray_[1][NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + return false; + }; + } + else + { + for(int i = 0; i < NumInput; i++) + { + if(p_arg_->inStridesArray_[i][Rank - 1] != 1) + return false; + } + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + }; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = KThreadSliceSize % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize)) + return false; + + if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize)) + return false; + + // if fastest dim is not reduced + if constexpr(XYSrcVectorDim == 0) // + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + + // if fastest dim is not reduced + if constexpr(XYSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return (false); + } + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::array, NumInput> inStridesArray, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + AccDataType epsilon, + const std::array in_dev_buffers, + const void* p_gamma, + const void* p_beta, + void* p_y, + XElementwiseOperation x_elementwise_op, + YElementwiseOperation y_elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + gammaStrides, + betaStrides, + yStrides, + reduceDims, + x_elementwise_op, + y_elementwise_op, + epsilon, + in_dev_buffers, + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseNormalizationImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b9a64e8c4b0bde853c1296eb75413110da53fb7e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -0,0 +1,875 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()> +{ + using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0)); + using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + BiasDataType, + D0DataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BiasDataType* p_bias_grid, + const D0DataType* p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ElementwiseOperation d0_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_bias_grid_{p_bias_grid}, + p_d0_grid_{p_d0_grid}, + p_reduces_grid_{p_reduces_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, + c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_element_op_{d0_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c1_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const BiasDataType* p_bias_grid_; + const D0DataType* p_d0_grid_; + ReducePtrsGlobal p_reduces_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + D0ElementwiseOperation d0_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + BiasDataType, + D0DataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_bias_grid_, + arg.p_d0_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + BiasDataType, + D0DataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_bias_grid_, + arg.p_d0_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t /* KBatch */ = 1) override + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasAddReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..19140688283ba4137e0420f6dcaead446c015fe6 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp @@ -0,0 +1,572 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_e_permute(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// input : A[M, K], or A[K, N] +// input : B[K, N], or A[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute +{ + using DeviceOp = DeviceGemmBiasEPermute_Xdl; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static constexpr index_t NumDTensor = 1; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) + { + index_t M0 = d_e_grid_desc.M0_; + index_t M1 = d_e_grid_desc.M1_; + index_t M2 = d_e_grid_desc.M2_; + index_t N0 = d_e_grid_desc.N0_; + index_t N1 = d_e_grid_desc.N1_; + + index_t stride_M0 = d_e_grid_desc.stride_M0_; + index_t stride_M1 = d_e_grid_desc.stride_M1_; + index_t stride_M2 = d_e_grid_desc.stride_M2_; + index_t stride_N0 = d_e_grid_desc.stride_N0_; + index_t stride_N1 = d_e_grid_desc.stride_N1_; + + const auto e_grid_desc_mraw_nraw = [&]() { + const auto e_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor( + make_tuple(M0, M1, M2, N0, N1), + make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1)); + + return transform_tensor_descriptor( + e_grid_desc_m0_m1_m2_n0_n1, + make_tuple(make_merge_transform(make_tuple(M0, M1, M2)), + make_merge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); + + using DsGridDesc_M_N = Tuple; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + const void* p_d_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + + if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + // populate pointer, desc for Ds + // D pointer + p_ds_grid_(I0) = static_cast(p_d_grid); + + // D desc + ds_grid_desc_m_n_(I0) = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_[I0]); + } + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_bias_e_permute< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_d, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_d, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasEPermute_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b08a418d764db72bba0c95bdc7fd5fb781b79ea7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -0,0 +1,595 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDl : public DeviceGemm + +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused, but may be useful in future. + index_t M01_; + index_t N01_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmDl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize( + arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1)); + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx906" ||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030") + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmDl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f1fb4ab4b1b255e99c4c5a74c07f97019cf55d87 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -0,0 +1,682 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + FloatRsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_rs_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = p_rs_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = qs_element_op; + ignore = rs_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = rs_grid_desc_mblock_mperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : R0[M], R1[M], ... +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ... +// R0 = r_op0(Q0), R1 = r_op1(Q1), ... +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle + : public DeviceGemmMultipleDMultipleR +{ + using DeviceOp = DeviceGemmMultipleDMultipleR_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + // assume D is packed tensor + static auto MakeRGridDescriptor_M(index_t MRaw) + { + const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(r_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return r_grid_desc_mraw; + } + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using RGridDesc_M = decltype(MakeRGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + DsDataType, + EDataType, + ReduceAccDataType, + RsDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + ThreadReduceOperations, + InMemoryDataOperationEnum::Set, + RsGlobalMemoryDataOperation, + AGridDesc_M_K, + BGridDesc_N_K, + EGridDesc_M_N, + RGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, + CDEReduceThreadTransferScalarPerVector_NPerBlock, + RThreadTransferDstScalarPerVector_MPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + std::array p_rs_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + p_rs_grid_{}, // FIXME + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + rs_grid_desc_mblock_mperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + qs_element_op_{qs_element_op}, + rs_element_op_{rs_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + e_grid_desc_m_n_, + r_grid_desc_m_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; + + p_rs_grid_(i) = static_cast(p_rs_grid[i]); + + rs_grid_desc_mblock_mperblock_(i) = + GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); + }); + } + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + typename GridwiseGemm::RsGridPointer p_rs_grid_; + + // tensor descriptors + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + EGridDesc_M_N e_grid_desc_m_n_; + RGridDesc_M r_grid_desc_m_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + QsElementwiseOperation qs_element_op_; + RsElementwiseOperation rs_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_multiple_r_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + typename GridwiseGemm::RsGridPointer, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ck::StaticallyIndexedArray< + typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, + NumRTensor>, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.p_rs_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.qs_element_op_, + arg.rs_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.rs_grid_desc_mblock_mperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::array p_rs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + p_rs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::array p_rs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + p_rs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmMultipleDMultipleR_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3830e1c0e822a7bb86622c860f623978cf0c6d53 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,698 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cf19083954d81525018dbede5e9405c17463e0e6 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp @@ -0,0 +1,835 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume Reduce is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}" + << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21bb36b7b39eb3704ebc53dc03700ea7cc7815e5 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -0,0 +1,570 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdl : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + kraw_{K} + { + a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t kraw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + if(arg.kraw_ % K1 != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cc8c8d4dab39060b4b4219d4422925e70bd492ed --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -0,0 +1,700 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemm_Xdl_CShuffle : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + kraw_{KRaw} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t kraw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + if((arg.kraw_ % AK1 != 0 || arg.kraw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer];; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..875623dcb68373b1c4aed91a16f124cdfa4765b9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +// +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +// +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator +{ + using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + static auto MakeGridDescriptor_N(index_t NRaw) + { + const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw)); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad N + return transform_tensor_descriptor(grid_desc_nraw, + make_tuple(make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad N + return grid_desc_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_N = decltype(MakeGridDescriptor_N(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + C0DataType, + ReduceAccDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + LoopSched>; + + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const C0DataType* p_c0_grid_add, + const C0DataType* p_c0_grid_bias, + const C0DataType* p_c0_grid_gamma, + const C0DataType* p_c0_grid_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_bias_{p_c0_grid_bias}, + p_c0_grid_add_{p_c0_grid_add}, + p_c0_grid_gamma_{p_c0_grid_gamma}, + p_c0_grid_beta_{p_c0_grid_beta}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_nblock_nperblock_{}, + block_2_ctile_map_{Block2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_nblock_nperblock_ = + GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const C0DataType* p_c0_grid_bias_; + const C0DataType* p_c0_grid_add_; + const C0DataType* p_c0_grid_gamma_; + const C0DataType* p_c0_grid_beta_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_N c0_grid_desc_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + true>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + false>; + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const C0DataType* p_c0_bias, + const C0DataType* p_c0_add, + const C0DataType* p_c0_gamma, + const C0DataType* p_c0_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0_bias, + p_c0_add, + p_c0_gamma, + p_c0_beta, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0_bias, + const void* p_c0_add, + const void* p_c0_gamma, + const void* p_c0_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0_bias), + static_cast(p_c0_add), + static_cast(p_c0_gamma), + static_cast(p_c0_beta), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + } + + std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmLayerNorm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp new file mode 100644 index 0000000000000000000000000000000000000000..42cabcea9edd1ab0d74155684f2d161d1a1409c0 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -0,0 +1,523 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSkipBLds : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + static_assert(BBlockBufferSize >= 2); + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferSrcScalarPerVector, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockBufferSize, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceGemmXdlSkipBLds::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceGemmXdlSkipBLds::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmXdlSkipBLds::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_ = + GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSkipBLds::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSkipBLds" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..50515189fa1dcdf6ecd01d1118561680f12654f5 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -0,0 +1,650 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto + MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto + MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t k_batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + k_batch_{k_batch} + { + int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = + DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1( + M, K, StrideA, k_batch_, KPad); + b_grid_desc_kbatch_k0_n_k1_ = + DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1( + K, N, StrideB, k_batch_, KPad); + c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSplitKCShuffle::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03d9e26a460962222c90f393df7115fcd4ecdf76 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,907 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_contraction_multiple_d_xdl_cshuffle( + const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto contraction_arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(contraction_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while((!(block_id >= contraction_arg_ptr[group_id].block_start_ && + block_id < contraction_arg_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < contraction_arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + contraction_arg_ptr[group_id].p_a_grid_, + contraction_arg_ptr[group_id].p_b_grid_, + contraction_arg_ptr[group_id].p_ds_grid_, + contraction_arg_ptr[group_id].p_e_grid_, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].block_2_etile_map_); +#else + ignore = contraction_args; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceGroupedContractionMultipleD_Xdl_CShuffle + : public DeviceGroupedContractionMultipleD +{ + using DeviceOp = DeviceGroupedContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_vec, + const std::vector& a_ms_ks_strides_vec) + { + assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK && + a_ms_ks_strides_vec.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_vec, + const std::vector& b_ns_ks_strides_vec) + { + assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK && + b_ns_ks_strides_vec.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_vec, + const std::vector& e_ms_ns_strides_vec) + { + assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN && + e_ms_ns_strides_vec.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i], + ds_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + struct GroupedContractionBlock2ETileMap + { + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, + ck::index_t BlockStart) + { + default_block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + block_start_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return default_block_2_etile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - block_start_)); + } + + // it's actually E-Tile + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return default_block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const + { + return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n); + } + + Block2ETileMap default_block_2_etile_map_; + ck::index_t block_start_; + }; + + struct ContractionMultiDKernelArg + { + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // lock-to-e-tile map + GroupedContractionBlock2ETileMap block_2_etile_map_; + + ck::index_t block_start_, block_end_; + }; + + struct ContractionMultiDDeviceArg + { + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + // index_t e_mz_stride_; + index_t e_nz_stride_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + group_count_ = contraction_descs.size(); + + if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && + group_count_ == p_e_vec.size())) + { + throw std::runtime_error("wrong! group_count_ != a/b/e_vec.size"); + } + + contraction_multi_d_kernel_args_.reserve(group_count_); + + grid_size_ = 0; + + for(std::size_t i = 0; i < group_count_; i++) + { + const auto p_a_grid = static_cast(p_a_vec[i]); + const auto p_b_grid = static_cast(p_b_vec[i]); + const auto p_e_grid = static_cast(p_e_vec[i]); + + const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K( + contraction_descs[i].a_ms_ks_lengths, contraction_descs[i].a_ms_ks_strides); + const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K( + contraction_descs[i].b_ns_ks_lengths, contraction_descs[i].b_ns_ks_strides); + + DsGridDesc_M_N ds_grid_desc_m_n; + typename GridwiseGemm::DsGridPointer p_ds_grid; + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid(j) = static_cast(p_ds_vec[i][j]); + + // D desc + ds_grid_desc_m_n(j) = + DeviceOp::MakeEGridDescriptor_M_N(contraction_descs[i].ds_ms_ns_lengths[j], + contraction_descs[i].ds_ms_ns_strides[j]); + }); + + const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N( + contraction_descs[i].e_ms_ns_lengths, contraction_descs[i].e_ms_ns_strides); + + const auto a_grid_desc_ak0_m_ak1 = + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + const auto b_grid_desc_bk0_n_bk1 = + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n); + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + + const index_t grid_size_grp = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n) + .CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + const auto block_2_etile_map = + GroupedContractionBlock2ETileMap(e_grid_desc_m_n, BlockStart); + + // for sanity check of vector memory access + const index_t a_mz_stride = contraction_descs[i].a_ms_ks_strides[NumDimM - 1]; + const index_t a_kz_stride = + contraction_descs[i].a_ms_ks_strides[NumDimM + NumDimK - 1]; + + const index_t b_nz_stride = contraction_descs[i].b_ns_ks_strides[NumDimN - 1]; + const index_t b_kz_stride = + contraction_descs[i].b_ns_ks_strides[NumDimN + NumDimK - 1]; + + std::array ds_nz_stride; + for(index_t j = 0; j < NumDTensor; ++j) + { + ds_nz_stride[j] = + contraction_descs[i].ds_ms_ns_strides[j][NumDimM + NumDimN - 1]; + } + + const index_t e_nz_stride = + contraction_descs[i].e_ms_ns_strides[NumDimM + NumDimN - 1]; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + contraction_multi_d_kernel_args_.push_back( + {p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map, + BlockStart, + BlockEnd}); + + contraction_multi_d_device_args_.push_back({a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_mz_stride, + a_kz_stride, + b_nz_stride, + b_kz_stride, + ds_nz_stride, + e_nz_stride}); + } + } + } + + std::vector contraction_multi_d_kernel_args_; + std::vector contraction_multi_d_device_args_; + + std::size_t group_count_; + index_t grid_size_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto K = + arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + hipGetErrorString(hipMemcpy(arg.p_workspace_, + arg.contraction_multi_d_kernel_args_.data(), + arg.contraction_multi_d_kernel_args_.size() * + sizeof(ContractionMultiDKernelArg), + hipMemcpyHostToDevice)); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = + kernel_grouped_contraction_multiple_d_xdl_cshuffle; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + }; + + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto a_grid_desc_m_k_ = arg.contraction_multi_d_device_args_[i].a_grid_desc_m_k_; + const auto b_grid_desc_n_k_ = arg.contraction_multi_d_device_args_[i].b_grid_desc_n_k_; + const auto ds_grid_desc_m_n_ = + arg.contraction_multi_d_device_args_[i].ds_grid_desc_m_n_; + const auto e_grid_desc_m_n_ = arg.contraction_multi_d_device_args_[i].e_grid_desc_m_n_; + const auto a_grid_desc_ak0_m_ak1_ = + arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_; + const auto b_grid_desc_bk0_n_bk1_ = + arg.contraction_multi_d_kernel_args_[i].b_grid_desc_bk0_n_bk1_; + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + arg.contraction_multi_d_kernel_args_[i] + .ds_grid_desc_mblock_mperblock_nblock_nperblock_; + const auto e_grid_desc_mblock_mperblock_nblock_nperblock_ = + arg.contraction_multi_d_kernel_args_[i] + .e_grid_desc_mblock_mperblock_nblock_nperblock_; + + const auto block_2_etile_map_ = + arg.contraction_multi_d_kernel_args_[i].block_2_etile_map_; + + const auto a_mz_stride_ = arg.contraction_multi_d_device_args_[i].a_mz_stride_; + const auto a_kz_stride_ = arg.contraction_multi_d_device_args_[i].a_kz_stride_; + const auto b_nz_stride_ = arg.contraction_multi_d_device_args_[i].b_nz_stride_; + const auto b_kz_stride_ = arg.contraction_multi_d_device_args_[i].b_kz_stride_; + const auto ds_nz_stride_ = arg.contraction_multi_d_device_args_[i].ds_nz_stride_; + const auto e_nz_stride_ = arg.contraction_multi_d_device_args_[i].e_nz_stride_; + + if(!GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(a_mz_stride_ == 1 && + a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(a_kz_stride_ == 1 && + a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(b_nz_stride_ == 1 && + b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(b_kz_stride_ == 1 && + b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + if(!(ds_nz_stride_[j] == 1 && + ds_grid_desc_mblock_mperblock_nblock_nperblock_[j].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!(e_nz_stride_ == 1 && e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + return false; + } + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a_vec, + p_b_vec, + p_ds_vec, + p_e_vec, + contraction_descs, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a_vec, + p_b_vec, + p_ds_vec, + p_e_vec, + contraction_descs, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * + sizeof(ContractionMultiDKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..682aba086008b607ac2f4313b6de45e66008bb63 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -0,0 +1,1015 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; +}; + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 + : public DeviceGroupedConvBwdDataMultipleD +{ + // FIXME + static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now"); + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO make A/B datatype different + using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto transform_conv_to_gemm = + TransformConvBwdDataToGemm_v1{}; + + static auto GetDummyABDsEGridDescriptor() + { + const std::array dummy_tensor_lengths = {1}; + const std::array dummy_tensor_strides = {1}; + const std::array dummy_spatial_lengths = {1}; + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return transform_conv_to_gemm.template MakeCDescriptor_M_N( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + }, + Number{}); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // desc + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + num_gemm_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, + ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + // problem definition + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + // number of GEMM + num_gemm_ = YTilde * XTilde; + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + // desc for problem definition + const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + block_2_etile_map_container_.push_back(block_2_etile_map); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } + } + } + } + + void Print() const + { + for(index_t i = 0; i < num_gemm_; i++) + { + std::cout << "a_grid_desc_ak0_m_ak1_container_" + << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; + + std::cout << "b_grid_desc_bk0_n_bk1_container_" + << b_grid_desc_bk0_n_bk1_container_[i] << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j] + << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t num_gemm_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + std::vector + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map + std::vector block_2_etile_map_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_k_wos_lengths_; + std::array a_g_n_k_wos_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_c_wis_lengths_; + std::array, NumDTensor> ds_g_n_c_wis_strides_; + std::array e_g_n_c_wis_lengths_; + std::array e_g_n_c_wis_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + for(index_t i = 0; i < arg.num_gemm_; i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_etile_map_container_[i])) + { + throw std::runtime_error("wrong! device_op has invalid setting"); + } + + const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]) * + arg.num_group_; + + const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_k_wos_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.block_2_etile_map_container_[i], + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + { + ave_time += launch_kernel(integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specifialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector load D matrix from global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_etile_map_container_[i])) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) + << ">"; + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d9e7b54cc8b14cb5ba2cd93d4dd295b6e3670f71 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp @@ -0,0 +1,1244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +struct ComputePtrOffsetOfStridedBatch +{ + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; +}; + +} // namespace + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_bwd_weight( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = batch_count; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1); + } + + // type convert descs + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * 4; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + template + static auto MakeDescriptor_M0(const std::array& shape, + const std::array& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{G}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = + N * K * + std::accumulate(begin(output_spatial_lengths), + end(output_spatial_lengths), + index_t{1}, + std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideB_ = + N * C * + std::accumulate(begin(input_spatial_lengths), + end(input_spatial_lengths), + index_t{1}, + std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideC_ = + K * C * + std::accumulate(begin(filter_spatial_lengths), + end(filter_spatial_lengths), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + index_t Conv_G_; + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::array output_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(has_main_k0_block_loop) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + G, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + G, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03185d5b1d29186f2b2814257eacfe90da6c5e02 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Grouped Convolution Forward: +// input : input image A[G, N, C, Hi, Wi], +// input : weight B[G, K, C, Y, X], +// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... +// output : output image E[G, N, K, Ho, Wo] +// output : R0[G, N, Ho, Wo], R1[G, N, Ho, Wo], ... +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ... +// R0 = r_op0(Q0), R1 = r_op1(Q1), ... +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGroupedConvFwdMultipleDMultipleR : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b54ee493c417983b01da6d27b19e5a557de5431 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -0,0 +1,1105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" +#include "ck/library/utility/numeric.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE, + Array BatchStrideRs) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE), + BatchStrideRs_(BatchStrideRs) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + __host__ __device__ constexpr auto GetRsPtrOffset(index_t g_idx) const + { + Array rs_offset; + static_for<0, NumRTensor, 1>{}( + [&](auto i) { rs_offset(i) = g_idx * static_cast(BatchStrideRs_[i]); }); + return rs_offset; + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; + Array BatchStrideRs_; +}; + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batch_gemm_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + RsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + RsPointer p_rs_grid_grp; + + static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size(); + + static_for<0, NumRTensor, 1>{}( + [&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_rs_grid_grp, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + rs_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = p_rs_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = rs_grid_desc_mblock_mperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = qs_element_op; + ignore = rs_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +} // namespace + +template +struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleDMultipleR +{ + using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto + MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto + MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + template + static auto GetPaddedRGridDescriptor(Descriptor descriptor, index_t MRaw) + { + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor( + descriptor, + make_tuple(make_right_pad_transform(descriptor, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return descriptor; + } + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto + MakeRGridDescriptor_M(const std::array& r_g_n_wos_lengths, + const std::array& /* r_g_n_wos_strides */) + { + const index_t N = r_g_n_wos_lengths[1]; + + const index_t NHoWo = + N * ck::accumulate_n( + r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>()); + + const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo)); + + return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); + } + + template || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeRGridDescriptor_M(const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides) + { + const index_t N = r_g_n_wos_lengths[1]; + + const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2]; + + const index_t NHoWo = + N * ck::accumulate_n( + r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>()); + + const auto r_grid_desc_mraw = + make_naive_tensor_descriptor(make_tuple(NHoWo), make_tuple(WoStride)); + + return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); + } + + using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + using RGridDesc_M = remove_cvref_t({}, {}))>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + ReduceAccDataType, + RsDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + ThreadReduceOperations, + InMemoryDataOperationEnum::Set, + RsGlobalMemoryDataOperation, + AGridDesc_M_K, + BGridDesc_N_K, + EGridDesc_M_N, + RGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + RThreadTransferDstScalarPerVector_MPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + p_rs_grid_{}, // FIXME + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + r_grid_desc_m_{ + DeviceOp::MakeRGridDescriptor_M(r_g_n_wos_lengths, r_g_n_wos_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + rs_grid_desc_mblock_mperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + qs_element_op_{qs_element_op}, + rs_element_op_{rs_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + e_grid_desc_m_n_, + r_grid_desc_m_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( + ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_(i)); + }); + + // populate pointer for Rs + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; + + // R pointer + p_rs_grid_(i) = static_cast(p_rs[i]); + + rs_grid_desc_mblock_mperblock_(i) = + GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); + }); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + typename GridwiseGemm::RsGridPointer p_rs_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + EGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + RGridDesc_M r_grid_desc_m_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + QsElementwiseOperation qs_element_op_; + RsElementwiseOperation rs_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * + arg.a_g_n_c_wis_lengths_[0]; // Group count + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + typename GridwiseGemm::RsGridPointer, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ck::StaticallyIndexedArray< + typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, + NumRTensor>, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.p_rs_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.qs_element_op_, + arg.rs_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.rs_grid_desc_mblock_mperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of R + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v)) + { + return false; + } + + // check Gridwise GEMM + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + p_rs, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + r_g_n_wos_lengths, + r_g_n_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + p_rs, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + r_g_n_wos_lengths, + r_g_n_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bb7a2f8c0a78963fd25380c50bd32615581eb8fe --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,952 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; +}; + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +} // namespace + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto + MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto + MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( + ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aabcc73a04061d851bb6de5130fbb8f6126c8bb5 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -0,0 +1,677 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr_, + gemm_desc_ptr[group_id].b_ptr_, + gemm_desc_ptr[group_id].ds_ptr_, + gemm_desc_ptr[group_id].e_ptr_, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_, + gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].block_2_etile_map_); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm +{ + using DeviceOp = DeviceGroupedGemm_Xdl; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumPrefetch, // NumGemmKPrefetchStage + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = remove_cvref_t; + using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + struct GroupedGemmBlock2ETileMap + { + using Block2ETileMap = + remove_cvref_t; + + GroupedGemmBlock2ETileMap() + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}); + BlockStart_ = -1; + } + + GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + BlockStart_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_2_etile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - BlockStart_)); + } + + // it's actually E-Tile + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const + { + return block_2_etile_map_.CheckValidity(e_grid_desc_m_n); + } + + Block2ETileMap block_2_etile_map_; + ck::index_t BlockStart_; + }; + + struct GemmBiasTransKernelArg + { + // pointers + const ADataType* a_ptr_; + const BDataType* b_ptr_; + typename GridwiseGemm::DsGridPointer ds_ptr_; + EDataType* e_ptr_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + GroupedGemmBlock2ETileMap block_2_etile_map_; + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideC = gemm_descs[i].stride_C_; + + // pointer + typename GridwiseGemm::DsGridPointer p_ds_grid{}; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + }); + + // tensor descriptors for problem definiton + const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB); + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideC); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + const index_t grid_size_grp = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0) + .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + // tensor descriptors for block/thread-wise copy + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + + gemm_desc_kernel_arg_.push_back( + GemmBiasTransKernelArg{static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map, + BlockStart, + BlockEnd}); + } + } + } + + // private: + index_t group_count_; + index_t skipped_group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{" + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) + << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1) + << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2) + << "}"; + + std::cout << ", arg.b_grid_desc_bk0_n_bk1_{" + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0) + << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1) + << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2) + << "}"; + + std::cout << ", arg.e_grid_desc_m_n_{ " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + + if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_, + arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_, + arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_, + arg.gemm_desc_kernel_arg_[i].block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + hipGetErrorString( + hipMemcpy(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg), + hipMemcpyHostToDevice)); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_grouped_gemm_xdl; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if((ck::type_convert(arg.gemm_desc_kernel_arg_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dbeeb980a56b34f27cef97d266845ff88c10c9bc --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp @@ -0,0 +1,595 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/sequence.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypeTuple::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static_assert(sequence_all_of(OutDstVectorSizeSeq{}, + [](auto vectorSize) { + return (MThreadSliceSize % vectorSize == 0); + }), + "The OutDstVectorSize should completely divide the MThreadSliceSize!"); + + static constexpr bool CheckDataTypeTuple() + { + bool flag = true; + + static_for<0, NumReduction, 1>{}([&](auto I) { + using OutDataType = remove_cvref_t; + flag = + flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType::value; + }); + + return flag; + }; + + static_assert(CheckDataTypeTuple(), + "The OutDataType must support the specified OutMemoryDataOperation!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumInputDim = Rank; + static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added + // later + static constexpr bool use_multiblock = + (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd); + + static_assert( + ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation), + "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!"); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto GenerateOutGrid1dDescTuple() + { + return generate_tuple( + [&](auto I) { + (void)I; + return MakeDst1dDescriptor(std::array{}, + std::array{}); + }, + Number{}); + }; + + using InGridDesc_M_K = decltype(MakeSrc2dDescriptor( + std::array{}, std::array{}, 1, 1)); + using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple()); + + static auto MakeDst1dDescriptorForBufferSet(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto length = out_grid_desc_m.GetLength(Number<0>{}); + + const auto pad = math::integer_least_multiple(length, BlockSize) - length; + + auto out_grid_desc_m_padded = + transform_tensor_descriptor(out_grid_desc_m, + make_tuple(make_right_pad_transform(length, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto GenerateOutGrid1dDescTuple_2() + { + return generate_tuple( + [&](auto I) { + (void)I; + return MakeDst1dDescriptorForBufferSet(std::array{}, + std::array{}); + }, + Number{}); + }; + + using OutGridDesc_M_Tuple_2 = decltype(GenerateOutGrid1dDescTuple_2()); + + struct Argument : public BaseArgument + { + Argument(const std::array& inLengths, + const std::array& inStrides, + const std::array& outLengths, + const std::array, NumReduction>& outStridesArray, + const std::array& reduceDims, + const std::array& alphas, + const std::array& betas, + const void* in_dev, + const std::array& out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) + : outLengths_{outLengths}, + outStridesArray_{outStridesArray}, + in_elementwise_op_tuple_{in_elementwise_op_tuple}, + acc_elementwise_op_tuple_{acc_elementwise_op_tuple} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + for(size_t i = 0; i < NumReduction; i++) + { + alpha_values_(i) = *static_cast(alphas[i]); + beta_values_(i) = *static_cast(betas[i]); + }; + + in_dev_ = static_cast(in_dev); + + out_dev_buffers_ = generate_tuple( + [&](auto iR) { + using OutDataTypePointer = + remove_cvref_t; + using OutDataType = remove_cvref_t>; + return static_cast(out_dev_buffers[iR]); + }, + Number{}); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(use_multiblock) + { + + int iterations = 1; + while(true) + { + int testBlkGroupSize = + (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = + (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + in_grid_desc_m_k = + MakeSrc2dDescriptor(inLengths_, inStrides_, blkGroupSize, numBlockTileIteration); + + out_grid_desc_m_tuple = generate_tuple( + [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); }, + Number{}); + + out_grid_desc_m_tuple_2 = generate_tuple( + [&](auto I) { + return MakeDst1dDescriptorForBufferSet(outLengths, outStridesArray[I]); + }, + Number{}); + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + + gridSize_pre = + math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; + } + + std::array inLengths_; + std::array inStrides_; + + std::array outLengths_; + std::array, NumReduction> outStridesArray_; + + Array alpha_values_; + Array beta_values_; + + const InDataType* in_dev_; + OutDataTypePointerTuple out_dev_buffers_; + + InGridDesc_M_K in_grid_desc_m_k; + OutGridDesc_M_Tuple out_grid_desc_m_tuple; + OutGridDesc_M_Tuple_2 out_grid_desc_m_tuple_2; + + InElementwiseOperationTuple in_elementwise_op_tuple_; + AccElementwiseOperationTuple acc_elementwise_op_tuple_; + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + size_t gridSize_pre; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using GridwiseMultipleReduce = + GridwiseMultipleReduction_mk_to_m_multiblock; + + const auto kernel_main = + kernel_multiple_reduce_multiblock; + + float avg_time = 0; + + if constexpr(use_multiblock) + { + auto identity_values = generate_tuple( + [&](auto iR) { + using OutDataType = remove_cvref_t; + return ck::reduce::GetIdentityValueForInMemoryDataOperation( + OutMemoryDataOperation); + }, + Number{}); + + const auto kernel_pre = kernel_multiple_buffer_set_value; + + avg_time += launch_and_time_kernel(stream_config, + kernel_pre, + dim3(arg.gridSize_pre), + dim3(BlockSize), + 0, + arg.out_grid_desc_m_tuple_2, + arg.out_dev_buffers_, + identity_values); + }; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k, + arg.out_grid_desc_m_tuple, + arg.in_elementwise_op_tuple_, + arg.acc_elementwise_op_tuple_, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_values_, + arg.in_dev_, + arg.beta_values_, + arg.out_dev_buffers_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(use_multiblock) + { + for(size_t i = 0; i < pArg->beta_values_.Size(); i++) + if(pArg->beta_values_[i] != 0.0f) + return (false); + }; + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0) + return (false); + }; + // To improve + bool valid = true; + static_for<0, NumReduction, 1>{}([&](auto I) { + if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 && + OutDstVectorSizeSeq::At(I) != 1) + valid = false; + + if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0) + valid = false; + }); + + if(!valid) + return (false); + + if constexpr(use_multiblock) + { + // blkGroupSize of 1 should be handled by Blockwise path using + // InMemoryDataOperationEnum::Set + if(pArg->blkGroupSize == 1) + return (false); + + // This is very strong restriction, but needed to avoid some failure + if(pArg->outLengths_[NumOutputDim - 1] % M_BlockTileSize != 0) + return (false); + } + else + { + // cases with very small reduce_total_length should be handled by ThreadWise kernel + if(pArg->reduce_total_length / KThreadSliceSize < 2) + return (false); + }; + + return (true); + }; + + std::unique_ptr MakeArgumentPointer( + const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array, NumReduction> outStridesArray, + const std::array reduceDims, + const std::array alphas, + const std::array betas, + const void* in_dev, + const std::array out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStridesArray, + reduceDims, + alphas, + betas, + in_dev, + out_dev_buffers, + in_elementwise_op_tuple, + acc_elementwise_op_tuple); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceMultipleReduceBlockWise<" : "DeviceMultipleReduceMultiBlock<") << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ","; + str << "OutDstVectorSize"; + static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); }); + str << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff8465e9fc736679a05b7059d0aa4e4342404757 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp @@ -0,0 +1,422 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/sequence.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypeTuple::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static_assert(sequence_all_of(OutDstVectorSizeSeq{}, + [](auto vectorSize) { + return (MThreadSliceSize % vectorSize == 0); + }), + "The OutDstVectorSize should completely divide the MThreadSliceSize!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumInputDim = Rank; + static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto GenerateOutGrid1dDescTuple() + { + return generate_tuple( + [&](auto I) { + (void)I; + return MakeDst1dDescriptor(std::array{}, + std::array{}); + }, + Number{}); + }; + + using InGridDesc_M_K = decltype(MakeSrc2dDescriptor(std::array{}, + std::array{})); + using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple()); + + struct Argument : public BaseArgument + { + Argument(const std::array& inLengths, + const std::array& inStrides, + const std::array& outLengths, + const std::array, NumReduction>& outStridesArray, + const std::array& reduceDims, + const std::array& alphas, + const std::array& betas, + const void* in_dev, + const std::array& out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) + : outLengths_{outLengths}, + outStridesArray_{outStridesArray}, + in_elementwise_op_tuple_{in_elementwise_op_tuple}, + acc_elementwise_op_tuple_{acc_elementwise_op_tuple} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + for(size_t i = 0; i < NumReduction; i++) + { + alpha_values_(i) = *static_cast(alphas[i]); + beta_values_(i) = *static_cast(betas[i]); + }; + + in_dev_ = static_cast(in_dev); + + out_dev_buffers_ = generate_tuple( + [&](auto iR) { + using OutDataTypePointer = + remove_cvref_t; + using OutDataType = remove_cvref_t>; + return static_cast(out_dev_buffers[iR]); + }, + Number{}); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + in_grid_desc_m_k = MakeSrc2dDescriptor(inLengths_, inStrides_); + + out_grid_desc_m_tuple = generate_tuple( + [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); }, + Number{}); + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + } + + std::array inLengths_; + std::array inStrides_; + + std::array outLengths_; + std::array, NumReduction> outStridesArray_; + + Array alpha_values_; + Array beta_values_; + + const InDataType* in_dev_; + OutDataTypePointerTuple out_dev_buffers_; + + InGridDesc_M_K in_grid_desc_m_k; + OutGridDesc_M_Tuple out_grid_desc_m_tuple; + + InElementwiseOperationTuple in_elementwise_op_tuple_; + AccElementwiseOperationTuple acc_elementwise_op_tuple_; + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using GridwiseMultipleReduce = + GridwiseMultipleReduction_mk_to_m_threadwise; + + const auto kernel_main = + kernel_multiple_reduce_threadwise; + + float avg_time = 0; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k, + arg.out_grid_desc_m_tuple, + arg.in_elementwise_op_tuple_, + arg.acc_elementwise_op_tuple_, + arg.alpha_values_, + arg.in_dev_, + arg.beta_values_, + arg.out_dev_buffers_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0) + return (false); + }; + + // To improve + bool valid = true; + static_for<0, NumReduction, 1>{}([&](auto I) { + if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 && + OutDstVectorSizeSeq::At(I) != 1) + valid = false; + + if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0) + valid = false; + }); + + if(!valid) + return (false); + + return (true); + }; + + std::unique_ptr MakeArgumentPointer( + const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array, NumReduction> outStridesArray, + const std::array reduceDims, + const std::array alphas, + const std::array betas, + const void* in_dev, + const std::array out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStridesArray, + reduceDims, + alphas, + betas, + in_dev, + out_dev_buffers, + in_elementwise_op_tuple, + acc_elementwise_op_tuple); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceMultipleReduceThreadwise<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ","; + str << "OutDstVectorSize"; + static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); }); + str << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..47d9df80255805729b6c36fed9ae11608052c51d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp @@ -0,0 +1,476 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const AccElementwiseOperation acc_elementwise_op) +{ + GridwiseReduction::Run(x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + num_k_block_tile_iteration, + epsilon, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + acc_elementwise_op); +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = Normalization(X, Beta, Gamma) +template +struct DeviceNormalizationImpl : public DeviceNormalization +{ + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || + (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t numSrcDim = Rank; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseReduceLayernormGeneric = + GridwiseNormalizationWelfordVariance_mk_to_mk; + using GridwiseNormalizationSweepOnce = + GridwiseNormalizationWelfordVariance_mk_to_mk; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + AccElementwiseOperation acc_elementwise_op, + AccDataType epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y) + : epsilon_(epsilon), + p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + acc_elementwise_op_(acc_elementwise_op) + { + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(Lengths_); + + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize_; + + x_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_); + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + + isSweeponce_ = + x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + } + + AccDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + + AccElementwiseOperation acc_elementwise_op_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K beta_grid_desc_m_k_; + GridDesc_M_K y_grid_desc_m_k_; + bool isSweeponce_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = arg.isSweeponce_ + ? kernel_normalization + : kernel_normalization; + + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.acc_elementwise_op_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + if constexpr(XYSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + }; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + + // if fastest dim is not reduced + if constexpr(GammaSrcVectorDim == 0) + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + + // if fastest dim is not reduced + if constexpr(BetaSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return (false); + } + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + AccDataType epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_saveMean, + void* p_saveInvVar, + AccElementwiseOperation acc_elementwise_op) override + { + // TODO + // Optional cache of the intermediate results (mean and InvVariance) during the + // forward pass could speedup in the backward + ignore = p_saveMean; + ignore = p_saveInvVar; + + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + reduceDims, + acc_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b96373c0ff1fc86a79c4248db35dd5812f2c5b1 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_permute.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Swap last 2 dimensions +// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]] +// ^^^^^^^^^^^ +// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]] +// ^^^^^^^^^^^ +template +struct DevicePermuteImpl : DevicePermute +{ + using BaseType = DevicePermute; + using typename BaseType::Lengths; + using typename BaseType::Strides; + + static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor"); + static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim); + static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim); + static_assert(SrcVectorDim != DstVectorDim); + + template + static auto ConvertArrayToTuple(const std::array& array) + { + static_assert(1 <= N && N <= NumDim); + + return generate_tuple([&](auto I) { return array[I]; }, Number{}); + } + + static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride) + { + // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], + // d[NumDim-1]] + const auto desc = + make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride)); + + // merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2], + // d[NumDim-1]] + // => [N, H, W] + const index_t H = *std::next(rbegin(lengths)); + const index_t W = *rbegin(lengths); + const auto desc_n_h_w = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(ConvertArrayToTuple(lengths)), + make_pass_through_transform(H), + make_pass_through_transform(W)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{}), + Sequence{}, + Sequence{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return PadTensorDescriptor( + desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence{}); + } + + using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1})); + using OutGridDesc = InGridDesc; + + using GridwisePermute = GridwisePermute< + InGridDesc, + OutGridDesc, + InDataType, + OutDataType, + ElementwiseOperation, + BlockSize, + NPerBlock, + HPerBlock, + WPerBlock, + InBlockLdsExtraW, + InBlockTransferThreadClusterLengths, + InBlockTransferThreadClusterArrangeOrder, + SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor + DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor + SrcScalarPerVector, + DstScalarPerVector>; + + using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap; + + struct Argument : public BaseArgument + { + Argument(const Lengths& in_lengths, + const Strides& in_strides, + const Lengths& out_lengths, + const Strides& out_strides, + const void* in_dev_buffer, + void* out_dev_buffer, + ElementwiseOperation elementwise_op) + : in_dev_buffer_(static_cast(in_dev_buffer)), + out_dev_buffer_(static_cast(out_dev_buffer)), + in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)), + out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)), + in_lengths_(in_lengths), + in_strides_(in_strides), + out_lengths_(out_lengths), + out_strides_(out_strides), + elementwise_op_(elementwise_op), + block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_)) + { + } + + const InDataType* in_dev_buffer_; + OutDataType* out_dev_buffer_; + InGridDesc in_grid_desc_; + OutGridDesc out_grid_desc_; + + Lengths in_lengths_; + Strides in_strides_; + Lengths out_lengths_; + Strides out_strides_; + + ElementwiseOperation elementwise_op_; + + Block2TileMap block_2_tile_map_; + }; + + struct Invoker : BaseInvoker + { + static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_); + + const auto kernel = kernel_nd_permute; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_, + arg.out_grid_desc_, + arg.in_dev_buffer_, + arg.out_dev_buffer_, + arg.elementwise_op_, + arg.block_2_tile_map_); + return elapsed_time; + } + + float Run(const BaseArgument* arg, + const StreamConfig& stream_config = StreamConfig{}) override final + { + const auto* const argument = dynamic_cast(arg); + if(!argument) + { + return NAN; + } + + return Run(*argument, stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) { + return math::integer_divide_ceil(length, tile_length) * tile_length; + }; + + constexpr auto IsScalarPerVectorValid = + [](index_t length, index_t stride, index_t scalar_per_vector) { + if(stride == 1 && length % scalar_per_vector == 0) + { + return true; + } + else if(stride != 1 && scalar_per_vector == 1) + { + return true; + } + + return false; + }; + + return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim], + arg.in_strides_[SrcVectorDim], + SrcScalarPerVector) && + IsScalarPerVectorValid( + GetPaddedLength(arg.in_lengths_[SrcVectorDim], + (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)), + arg.in_strides_[SrcVectorDim], + SrcScalarPerVector) && + IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim], + arg.out_strides_[DstVectorDim], + DstScalarPerVector) && + IsScalarPerVectorValid( + GetPaddedLength(arg.out_lengths_[DstVectorDim], + (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)), + arg.in_strides_[DstVectorDim], + DstScalarPerVector) && + GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_); + }; + + // override methods inherited from 'BaseOperator' + bool IsSupportedArgument(const BaseArgument* arg) override final + { + const auto* const argument = dynamic_cast(arg); + if(!argument) + { + return false; + } + + return IsSupportedArgument(*argument); + } + + // override methods inherited from 'DevicePermute' + std::unique_ptr + MakeArgumentPointer(const Lengths& in_lengths, + const Strides& in_strides, + const Lengths& out_lengths, + const Strides& out_strides, + const void* in_dev_buffer, + void* out_dev_buffer, + ElementwiseOperation elementwise_op) override final + { + return std::make_unique(in_lengths, + in_strides, + out_lengths, + out_strides, + in_dev_buffer, + out_dev_buffer, + elementwise_op); + } + + std::unique_ptr MakeInvokerPointer() override final + { + return std::make_unique(); + }; + + // other constructor methods + template + static std::enable_if_t, Argument> + MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v) + { + return Argument{std::forward(args)...}; + } + + static std::enable_if_t, Invoker> + MakeInvoker() noexcept(std::is_nothrow_default_constructible_v) + { + return Invoker{}; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bfde40cda2a5369015b08de568f72c786295848f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using IndexDataType = int32_t; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + static constexpr index_t InSrcOutDstVectorDim = + 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is + // not reduced. + + static constexpr ck::index_t ReduceM_BlockTileSize = + ReduceMThreadClusterSize * ReduceMThreadSliceSize; + static constexpr ck::index_t ReduceK_BlockTileSize = + ReduceKThreadClusterSize * ReduceKThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) + { + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = window_spatial_lengths[0]; + const index_t X = window_spatial_lengths[1]; + + const index_t ConvStrideH = window_strides[0]; + const index_t ConvStrideW = window_strides[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ReduceMRaw = N * Ho * Wo * C; + const index_t ReduceMPad = + math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw; + + const index_t ReduceKRaw = Y * X; + const index_t ReduceKPad = + math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw; + + // A[ReduceM, ReduceK] + const auto in_grid_desc_n_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_grid_desc_reducemraw_reducekraw = + transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), + make_merge_transform(make_tuple(Y, X))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad), + make_right_pad_transform(ReduceKRaw, ReduceKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const auto out_grid_desc_reducemraw = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C)); + + const auto out_grid_desc_reducem = transform_tensor_descriptor( + out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = decltype( + MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + // TODO + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + int* p_out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array& input_spatial_lengths, + std::array& window_spatial_lengths, + std::array& output_spatial_lengths, + std::array& window_strides, + std::array& input_left_pads, + std::array& input_right_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, + C, + input_spatial_lengths, + window_spatial_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + invariant_lowest_length_ = C; + reduce_lowest_length_ = window_spatial_lengths[1]; + + int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + int* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + ck::index_t invariant_lowest_length_; + ck::index_t reduce_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = kernel_reduce_threadwise; + + ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (ReduceM / ReduceM_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) + { + return (false); + } + + return (true); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + ck::index_t N, + ck::index_t C, + std::array input_spatial_lengths, + std::array window_spatial_lengths, + std::array output_spatial_lengths, + std::array window_strides, + std::array input_left_pads, + std::array input_right_pads) override + { + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + N, + C, + input_spatial_lengths, + window_spatial_lengths, + output_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ","; + str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ","; + str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5dc051be3cb7f23fef076921e826a7e55e017d76 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_operator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those +// of reduce dims +template +std::pair get_2d_lengths(const std::vector& inLengths) +{ + static_assert(Rank <= 6, "bigger Rank size not supported!"); + + long_index_t invariant_total_length = 1; + long_index_t reduce_total_length = 1; + + constexpr int NumInvariantDim = Rank - NumReduceDim; + + for(int i = NumInvariantDim; i < Rank; i++) + reduce_total_length *= inLengths[i]; + + for(int i = 0; i < NumInvariantDim; i++) + invariant_total_length *= inLengths[i]; + + return std::make_pair(invariant_total_length, reduce_total_length); +}; + +template +std::pair get_2d_lengths(const std::array& inLengths) +{ + static_assert(Rank <= 6, "bigger Rank size not supported!"); + + long_index_t invariant_total_length = 1; + long_index_t reduce_total_length = 1; + + constexpr int NumInvariantDim = Rank - NumReduceDim; + + for(int i = NumInvariantDim; i < Rank; i++) + reduce_total_length *= inLengths[i]; + + for(int i = 0; i < NumInvariantDim; i++) + invariant_total_length *= inLengths[i]; + + return std::make_pair(invariant_total_length, reduce_total_length); +}; + +// helper functions using variadic template arguments +template +auto make_tuple_from_array_and_index_seq(const std::vector& lengths, Sequence) +{ + return make_tuple(static_cast(lengths[Ns])...); +}; + +template +auto make_tuple_from_array(const std::vector& lengths, Number) +{ + static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); + + constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; + + return make_tuple_from_array_and_index_seq(lengths, index_seq); +}; + +template +std::vector shuffle_tensor_dimensions(const std::vector& origLengthsStrides, + const std::vector& reduceDims) +{ + std::vector newLengthsStrides; + + assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size()); + + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + // collect invariant dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + newLengthsStrides.push_back(origLengthsStrides[i]); + }; + + // collect reduce dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) > 0) + { + newLengthsStrides.push_back(origLengthsStrides[i]); + }; + + return newLengthsStrides; +}; + +template +std::array +shuffle_tensor_dimensions(const std::array& origLengthsStrides, + const std::array& reduceDims) +{ + std::array newLengthsStrides; + + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + // collect invariant dimensions + int pos = 0; + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + newLengthsStrides[pos++] = origLengthsStrides[i]; + }; + + // collect reduce dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) > 0) + { + newLengthsStrides[pos++] = origLengthsStrides[i]; + }; + + return newLengthsStrides; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..93855eb33e1ba2a35e8fb31a9cc6994338d26aab --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp @@ -0,0 +1,543 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceMultiBlock + : public DeviceReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added + // later + static constexpr bool use_multiblock = + (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd); + + static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType::value, + "The OutDataType must support the specified OutMemoryDataOperation!"); + + static_assert(!use_multiblock || (use_multiblock && !OutputIndex), + "MultiBlock reduction can only be used when outputing index is not required"); + + static_assert( + ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation), + "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!"); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto MakeDst1dDescriptorForBufferSet(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto length = out_grid_desc_m.GetLength(Number<0>{}); + + const auto pad = math::integer_least_multiple(length, BlockSize) - length; + + auto out_grid_desc_m_padded = + transform_tensor_descriptor(out_grid_desc_m, + make_tuple(make_right_pad_transform(length, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + float alpha, + float beta, + const InDataType* in_dev, + const IndexDataType* in_index_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + in_index_dev_{in_index_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + if(Rank != inLengths.size() || Rank != inStrides.size() || + NumReduceDim != reduceDims.size()) + { + throw std::runtime_error( + "One of inLengths/inStrides/reduceDims has invalid size!" + "\nExpected size inLengths: " + + std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) + + ", reduceDims: " + std::to_string(NumReduceDim) + + "\nBut have inLengths: " + std::to_string(inLengths.size()) + + ", inStrides: " + std::to_string(inStrides.size()) + + ", reduceDims: " + std::to_string(reduceDims.size())); + } + + for(std::size_t i = 0; i < reduceDims.size(); ++i) + { + if(reduceDims[i] < 0 || reduceDims[i] >= Rank) + { + throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!" + "\nHave reduceDims[" + + std::to_string(i) + + "]: " + std::to_string(reduceDims[i])); + } + } + + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + if constexpr(use_multiblock) + { + + int iterations = 1; + while(true) + { + int testBlkGroupSize = + (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = + (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + + gridSize_pre = + math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; + } + + std::array inLengths_; + std::array inStrides_; + std::array outLengths_; + std::array outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + const IndexDataType* in_index_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + size_t gridSize_pre; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m = + DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet( + arg.outLengths_, arg.outStrides_); + + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + using OutGridDesc_M_2 = decltype(out_grid_desc_m_2); + + using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock; + + const auto kernel_main = kernel_reduce_multiblock; + + float avg_time = 0; + + if constexpr(use_multiblock) + { + const auto identityVal = + ck::reduce::GetIdentityValueForInMemoryDataOperation( + OutMemoryDataOperation); + + const auto kernel_pre = + kernel_buffer_set_value; + + avg_time += launch_and_time_kernel(stream_config, + kernel_pre, + dim3(arg.gridSize_pre), + dim3(BlockSize), + 0, + out_grid_desc_m_2, + arg.out_dev_, + identityVal); + }; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.in_index_dev_, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument* pArg) + { + if constexpr(use_multiblock) + { + if(static_cast(pArg->beta_) != 0.0f) + return (false); + }; + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + if constexpr(use_multiblock) + { + // blkGroupSize of 1 should be handled by Blockwise path using + // InMemoryDataOperationEnum::Set + if(pArg->blkGroupSize == 1) + return (false); + + // This is very strong restriction, but needed to avoid some failure + if(pArg->invariant_lowest_length % M_BlockTileSize != 0) + return (false); + } + else + { + // cases with very small reduce_total_length should be handled by ThreadWise kernel + // if(pArg->reduce_total_length / KThreadSliceSize < 2) + // return (false); + }; + + return (true); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(in_index_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..05e14f080efe44d1715acc9a506e4c4bf3c4a9bd --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceThreadWise + : public DeviceReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + float alpha, + float beta, + const InDataType* in_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + } + + std::array inLengths_; + std::array inStrides_; + std::array outLengths_; + std::array outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = + DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); + const auto out_grid_desc_m = + DeviceReduceThreadWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + + float avg_time = 0; + + using GridwiseReduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = kernel_reduce_threadwise; + + avg_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.alpha_, + arg.in_dev_, + nullptr, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + // cases with big reduce_total_length should be handled by Blockwise kernel + if(pArg->reduce_total_length / KThreadSliceSize >= 32) + return (false); + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + float alpha, + float beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + (void)in_index_dev; + + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceThreadWise<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8630a2c6e227c8b929a3644e75af8f849d32e92c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp @@ -0,0 +1,423 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSoftmaxImpl : public DeviceSoftmax +{ + static constexpr index_t kRank = Rank; + static constexpr index_t kNumReduceDim = NumReduceDim; + static constexpr index_t kNumInvariantDim = Rank - NumReduceDim; + + virtual index_t GetRank() const override { return kRank; } + + virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk; + + using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + AccDataType alpha, + AccDataType beta, + const InDataType* in_dev, + OutDataType* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) + : alpha_{alpha}, + beta_{beta}, + in_dev_{in_dev}, + out_dev_{out_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + if(Rank != inLengths.size() || Rank != inStrides.size() || + NumReduceDim != reduceDims.size()) + { + throw std::runtime_error( + "One of inLengths/inStrides/reduceDims has invalid size!" + "\nExpected size inLengths: " + + std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) + + ", reduceDims: " + std::to_string(NumReduceDim) + + "\nBut have inLengths: " + std::to_string(inLengths.size()) + + ", inStrides: " + std::to_string(inStrides.size()) + + ", reduceDims: " + std::to_string(reduceDims.size())); + } + + for(std::size_t i = 0; i < reduceDims.size(); ++i) + { + if(reduceDims[i] < 0 || reduceDims[i] >= Rank) + { + throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!" + "\nHave reduceDims[" + + std::to_string(i) + + "]: " + std::to_string(reduceDims[i])); + } + } + + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = inLengths_[NumInvariantDim - 1]; + + blkGroupSize = 1; + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + } + + std::vector inLengths_; + std::vector inStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + + InElementwiseOp in_elementwise_op_; + AccElementwiseOp acc_elementwise_op_; + + index_t invariant_lowest_length_; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + + bool sweep_once = + in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + const auto kernel_main = sweep_once ? kernel_softmax + : kernel_softmax; + + float avg_time = 0; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m_k, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.beta_, + arg.out_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(InSrcVectorDim == 0) + { + if constexpr(kNumInvariantDim == 0) + { + return false; + } + else + { + if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1) + { + return false; + } + if(arg.invariant_lowest_length_ % InSrcVectorSize != 0) + { + return false; + } + } + } + else + { + if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1) + { + return false; + } + if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0) + { + return false; + } + } + + // To improve + if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) + { + return false; + } + + if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0) + { + return false; + } + + return true; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + const AccDataType alpha, + const AccDataType beta, + const InDataType* in_dev, + OutDataType* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) + { + return Argument{inLengths, + inStrides, + reduceDims, + alpha, + beta, + in_dev, + out_dev, + in_elementwise_op, + acc_elementwise_op}; + }; + + // + // @brief Makes a pointer to Argument class. + // + // @param[in] inLengths Input tensor extent(s) from high to low dimension + // @param[in] inStrides Input tensor stride(s) from high to low dimension + // @param[in] reduceDims The dimension(s) the normalization operation is applied + // @param[in] alpha Typeless pointer in host memory storing the alpha scaling + // value as type AccDataType + // @param[in] beta Typeless pointer in host memory storing the beta scaling + // value as type AccDataType + // @param[in] in_dev Typeless const pointer in device memory storing the input + // tensor + // @param out_dev Typeless pointer in device memory storing the output tensor + // @param[in] in_elementwise_op The input elementwise operation. + // @param[in] acc_elementwise_op The accumulation elementwise operation. + // + // @return Unique pointer to the Argument class. + // + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + const void* alpha, + const void* beta, + const void* in_dev, + void* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + reduceDims, + *static_cast(alpha), + *static_cast(beta), + static_cast(in_dev), + static_cast(out_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceSoftmax<" + << Rank << "," << NumReduceDim << "," << BlockSize << "," + << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << "," + << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << "," + << "InSrcVectorDim_" << InSrcVectorDim + << "_InSrcVectorSize_" << InSrcVectorSize + << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f2b46edd3c1f70edbd6247c1919fad9e0d6b1ad --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator +{ + + static auto MakeOutputDescriptor(const index_t index_length, const index_t rows) + { + return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows)); + } + + struct Argument : public BaseArgument + { + Argument(OutType* p_out, + const EmbType* p_emb_a, + const EmbType* p_emb_b, + const EmbType* p_emb_c, + const IndexType* p_index_a, + const IndexType* p_index_b, + const IndexType* p_index_c, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const ck::index_t NumRows, + const ck::index_t EmbeddingDim, + const ck::index_t IndexLength, + const AccDataType epsilon) + : p_out_(p_out), + p_emb_a_(p_emb_a), + p_emb_b_(p_emb_b), + p_emb_c_(p_emb_c), + p_index_a_(p_index_a), + p_index_b_(p_index_b), + p_index_c_(p_index_c), + p_gamma_(p_gamma), + p_beta_(p_beta), + NumRows_(NumRows), + EmbeddingDim_(EmbeddingDim), + IndexLength_(IndexLength), + epsilon_(epsilon) + { + grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize; + } + + OutType* p_out_; + const EmbType* p_emb_a_; + const EmbType* p_emb_b_; + const EmbType* p_emb_c_; + const IndexType* p_index_a_; + const IndexType* p_index_b_; + const IndexType* p_index_c_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + ck::index_t NumRows_; + ck::index_t EmbeddingDim_; + ck::index_t IndexLength_; + AccDataType epsilon_; + + size_t grid_size_; + }; + + virtual std::unique_ptr MakeArgumentPointer(void* p_out, + const void* p_emb_a, + const void* p_emb_b, + const void* p_emb_c, + const void* p_index_a, + const void* p_index_b, + const void* p_index_c, + const void* p_gamma, + const void* p_beta, + ck::index_t NumRows, + ck::index_t EmbeddingDim, + ck::index_t IndexLength, + const AccDataType epsilon) + { + return std::make_unique(reinterpret_cast(p_out), + reinterpret_cast(p_emb_a), + reinterpret_cast(p_emb_b), + reinterpret_cast(p_emb_c), + reinterpret_cast(p_index_a), + reinterpret_cast(p_index_b), + reinterpret_cast(p_index_c), + reinterpret_cast(p_gamma), + reinterpret_cast(p_beta), + NumRows, + EmbeddingDim, + IndexLength, + epsilon); + } + + using GridwiseSparseEmbedding = + GridwiseSparseEmbedding3ForwardLayernorm; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_); + const auto kernel_main = + kernel_sparse_embedding3_forward_layernorm; + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + arg.p_out_, + arg.p_emb_a_, + arg.p_emb_b_, + arg.p_emb_c_, + arg.p_index_a_, + arg.p_index_b_, + arg.p_index_c_, + arg.p_gamma_, + arg.p_beta_, + out_desc, + arg.epsilon_); + + return (avg_time); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument* p_arg) + { + return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); + } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" << + DimClusterSize << "x" << RowClusterSize << "_" << + DimPerBlock << "x" << RowPerBlock << "_" << + DimThreadSize << "x" << RowVectorSize; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/masking_specialization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea0f5897a757c408047b20e0cd4285e13f67c4f5 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct MaskingSpecialization +{ + MaskDisabled, + MaskOutUpperTriangle +}; + +inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) +{ + switch(s) + { + case MaskingSpecialization::MaskDisabled: return "MaskDisabled"; + case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle"; + default: return "Unrecognized specialization!"; + } +} + +struct MaskDisabledPredicate +{ + __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const + { + return false; + }; + + __host__ __device__ constexpr bool + IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const + { + return false; + } +}; + +struct MaskOutUpperTrianglePredicate +{ + __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const + { + return operator()(m + m_tile - 1, n); + } +}; + +// to track the points which need to be set to -inf on C0 +// Note: no need to reset M padding value, because they will not be stored out. +template +struct C0MatrixMask_impl +{ + C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} + + __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const + { + return n >= NRaw_; + } + + __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const + { + return predicate_(m, n) || IsNOutOfBound(n); + } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const + { + return predicate_.IsTileSkippable(m, n, m_tile, n_tile); + } + + private: + // index_t MRaw_; + index_t NRaw_; + MaskOutPredicate predicate_; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/matrix_padder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70e61bc7728f9638e5bef874ab0d5caad2db5028 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + typename DoPads> // Sequence +__host__ __device__ constexpr auto +PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads) +{ + constexpr index_t num_dim = DoPads::Size(); + + static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(), + "wrong! inconsistent # of dimensions"); + + // transforms + const auto transforms = generate_tuple( + [&](auto idim) { + const auto MRaw = desc.GetLength(idim); + + const auto MPerTile = tile_lengths[idim]; + + const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile; + + const auto MPad = M - MRaw; + + const bool DoPadM = DoPads::At(idim); + + const auto MTransform = conditional_expr(make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(MRaw)); + + return MTransform; + }, + Number{}); + + // lower dimension Id + const auto lower_dimss = + generate_tuple([&](auto idim) { return Sequence{}; }, Number{}); + + // upper dimension Id + const auto upper_dimss = lower_dimss; + + return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss); +} + +// M/N/K/OPerTileType could be index_t or Number<> +template +struct GemmGemmPadder +{ + // TODO: hard to scale; use mask instead + static constexpr bool PadM = + GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::MOPadding || GemmSpec == GemmSpecialization::MNOPadding || + GemmSpec == GemmSpecialization::MKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + static constexpr bool PadN = + GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::MNOPadding || + GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + static constexpr bool PadK = + GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KOPadding || GemmSpec == GemmSpecialization::MKOPadding || + GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + static constexpr bool PadO = + GemmSpec == GemmSpecialization::OPadding || GemmSpec == GemmSpecialization::MOPadding || + GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::KOPadding || + GemmSpec == GemmSpecialization::MNOPadding || GemmSpec == GemmSpecialization::MKOPadding || + GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + + // A[M, K] + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); + } + + // B[K, N] + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); + } + + // B1[Gemm1N, Gemm1K] = B1[O, N] + template + __host__ __device__ constexpr auto + PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence{}); + } + + // C[M, Gemm1N] = C[M, O] + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence{}); + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; + OPerTileType OPerTile_; +}; + +// M/N/KPerTileType could be index_t or Number<> +template +struct GemmPadder +{ + static constexpr bool PadM = + (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding); + static constexpr bool PadN = + (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding); + static constexpr bool PadK = + (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding); + + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence{}); + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; +}; + +// Alias of GemmPadder; to deprecate +template +struct MatrixPadder : public GemmPadder +{ +}; + +// M/N/KPerTileType could be index_t or Number<> +template +struct GemmPadder_v2 +{ + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence{}); + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; +}; + +// M/N/KPerTileType could be index_t or Number<> +template +struct MatrixPadder_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + const auto MRaw = a_desc_mraw_kraw.GetLength(I0); + const auto KRaw = a_desc_mraw_kraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_; + const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(PadM && PadK) + { + // pad both M and K + return transform_tensor_descriptor(a_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadM && (!PadK)) + { + // pad M, but not K + return transform_tensor_descriptor( + a_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr((!PadM) && PadK) + { + // pad K, but not M + return transform_tensor_descriptor( + a_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or K + return a_desc_mraw_kraw; + } + } + + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + const auto NRaw = b_desc_nraw_kraw.GetLength(I0); + const auto KRaw = b_desc_nraw_kraw.GetLength(I1); + + const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_; + const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(PadN && PadK) + { + // pad both N and K + return transform_tensor_descriptor(b_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadN && (!PadK)) + { + // pad N, but not K + return transform_tensor_descriptor( + b_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr((!PadN) && PadK) + { + // pad K, but not N + return transform_tensor_descriptor( + b_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad N or K + return b_desc_nraw_kraw; + } + } + + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + const auto MRaw = c_desc_mraw_nraw.GetLength(I0); + const auto NRaw = c_desc_mraw_nraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_; + const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(PadM && PadN) + { + // pad M and N + return transform_tensor_descriptor(c_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadM && (!PadN)) + { + // pad M, but not N + return transform_tensor_descriptor( + c_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr((!PadM) && PadN) + { + // pad N, but not M + return transform_tensor_descriptor( + c_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_desc_mraw_nraw; + } + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d35318357a9bcc6fb087c92e27dcf435ff60796b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// FIXME: can it be replaced with ck::Tuple? +#include + +namespace ck { + +// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their +// respective functor classes. +// The boolean member "indexable" are also provided in reduce_binary_operactor for +// easier checking by the upper-layer codes in the kernels. + +template +struct reduce_binary_operator; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Mul; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Min; + + static constexpr bool indexable = true; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Max; + + static constexpr bool indexable = true; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::AMax; + + static constexpr bool indexable = true; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary +// functor classes. +// The two unary functors are called before and afer the Reduction is executed respectively +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength}); + }; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template <> +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template <> +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template <> +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/tensor_layout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b44427411f9a28511fd16f6b47ee37fa84a0e2d3 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -0,0 +1,417 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_layout { + +struct BaseTensorLayout +{ +}; + +namespace gemm { + +struct RowMajor : public BaseTensorLayout +{ + static constexpr const char* name = "RowMajor"; +}; + +struct ColumnMajor : public BaseTensorLayout +{ + static constexpr const char* name = "ColumnMajor"; +}; +} // namespace gemm + +namespace convolution { + +// input tensor +// packed NCW/NCHW/NCDHW +struct NCW : public BaseTensorLayout +{ + static constexpr const char* name = "NCW"; +}; + +struct NCHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCHW"; +}; + +struct NCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCDHW"; +}; + +// packed GNCW/GNCHW/GNCDHW +struct GNCW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCW"; +}; + +struct GNCHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCHW"; +}; + +struct GNCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCDHW"; +}; + +// input tensor +// packed NWC/NHWC/NDHWC +struct NWC : public BaseTensorLayout +{ + static constexpr const char* name = "NWC"; +}; + +struct NHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWC"; +}; + +struct NDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWC"; +}; + +// input tensor +// packed GNWC/GNHWC/GNDHWC +struct GNWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNWC"; +}; + +struct GNHWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNHWC"; +}; + +struct GNDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHWC"; +}; + +// for input bias +struct GC : public BaseTensorLayout +{ + static constexpr const char* name = "GC"; +}; + +// input tensor +// packed NWGC/NHWGC/NDHWGC +struct NWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NWGC"; +}; + +struct NHWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWGC"; +}; + +struct NDHWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWGC"; +}; + +// input tensor +// strided layout +struct G_NW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW_C"; +}; + +struct G_NHW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW_C"; +}; + +struct G_NDHW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW_C"; +}; + +// for input bias +struct G_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_C"; +}; + +// weight tensor +// packed KCX/KCYX/KCZYX +struct KCX : public BaseTensorLayout +{ + static constexpr const char* name = "KCX"; +}; + +struct KCYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCYX"; +}; + +struct KCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCZYX"; +}; + +// weight tensor +// packed KCX/KCYX/KCZYX +struct GKCX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCX"; +}; + +struct GKCYX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCYX"; +}; + +struct GKCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCZYX"; +}; + +// weight tensor +// packed KXC/KYXC/KZYXC +struct KXC : public BaseTensorLayout +{ + static constexpr const char* name = "KXC"; +}; + +struct KYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXC"; +}; + +struct KZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXC"; +}; + +// weight tensor +// packed GKXC/GKYXC/GKZYXC +struct GKXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKXC"; +}; + +struct GKYXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKYXC"; +}; + +struct GKZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKZYXC"; +}; + +// weight tensor +// packed KXGC/KYXGC/KZYXGC +struct KXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KXGC"; +}; + +struct KYXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXGC"; +}; + +struct KZYXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXGC"; +}; + +// weight tensor +// strided +struct G_K_X_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_X_C"; +}; + +struct G_K_YX_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_YX_C"; +}; + +struct G_K_ZYX_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_ZYX_C"; +}; + +// output tensor +// packed NKW/NKHW/NKDHW +struct NKW : public BaseTensorLayout +{ + static constexpr const char* name = "NKW"; +}; + +struct NKHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKHW"; +}; + +struct NKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKDHW"; +}; + +// output tensor +// packed GNKW/GNKHW/GNKDHW +struct GNKW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKW"; +}; + +struct GNKHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKHW"; +}; + +struct GNKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKDHW"; +}; + +// output tensor +// packed NWK/NHWK/NDHWK +struct NWK : public BaseTensorLayout +{ + static constexpr const char* name = "NWK"; +}; + +struct NHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWK"; +}; + +struct NDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWK"; +}; + +// output tensor +// packed GNWK/GNHWK/GNDHWK +struct GNWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNWK"; +}; + +struct GNHWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNHWK"; +}; + +struct GNDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHWK"; +}; + +// for output bias +struct GK : public BaseTensorLayout +{ + static constexpr const char* name = "GK"; +}; + +// output tensor +// packed NWGK/NHWGK/NDHWGK +struct NWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NWGK"; +}; + +struct NHWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWGK"; +}; + +struct NDHWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWGK"; +}; + +// output tensor +// strided layout +struct G_NW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW_K"; +}; + +struct G_NHW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW_K"; +}; + +struct G_NDHW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW_K"; +}; + +// for output bias +struct G_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_K"; +}; + +// K-reduced output tensor (packed) +struct GNW : public BaseTensorLayout +{ + static constexpr const char* name = "GNW"; +}; + +struct GNHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNHW"; +}; + +struct GNDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHW"; +}; + +// K-reduced output tensor (packed) +struct NWG : public BaseTensorLayout +{ + static constexpr const char* name = "NWG"; +}; + +struct NHWG : public BaseTensorLayout +{ + static constexpr const char* name = "NHWG"; +}; + +struct NDHWG : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWG"; +}; + +// K-reduced output tensor (strided) +struct G_NW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW"; +}; + +struct G_NHW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW"; +}; + +struct G_NDHW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW"; +}; + +} // namespace convolution + +template < + typename Layout, + typename std::enable_if::value, bool>::type = false> +std::ostream& operator<<(std::ostream& os, const Layout&) +{ + os << Layout::name; + return os; +} + +} // namespace tensor_layout +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0ec0df2c9bbbb2aac987d681a36d5b00a8349c97 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct TensorSpecialization +{ + Default, + Packed +}; + +inline std::string getTensorSpecializationString(const TensorSpecialization& s) +{ + switch(s) + { + case TensorSpecialization::Default: return "Default"; + case TensorSpecialization::Packed: return "Packed"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/welford_helper.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/welford_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6c909b767d4602914660d045b9aaf5a2307ea57b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/device/welford_helper.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GetReduceCountPerThreadForBlockwiseWelford +{ + GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration, + long_index_t reduce_length) + : numBlockTileIteration_{numBlockTileIteration} + { + count_in_last_tile_ = reduce_length % K_BlockTileSize; + }; + + __device__ index_t operator()(index_t thread_k_cluster_id) const + { + if(count_in_last_tile_ == 0) + return (KThreadSliceSize * numBlockTileIteration_); + else + { + index_t num_complete_slice = count_in_last_tile_ / KThreadSliceSize; + index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize; + + if(thread_k_cluster_id < num_complete_slice) + return (KThreadSliceSize * numBlockTileIteration_); + else if(thread_k_cluster_id == num_complete_slice) + return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice); + else + return (KThreadSliceSize * (numBlockTileIteration_ - 1)); + }; + }; + + index_t numBlockTileIteration_; + index_t count_in_last_tile_; +}; + +template +struct GetReduceCountPerThreadForMultiblockWelford +{ + GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize, + index_t numBlockTileIteration, + long_index_t reduce_length) + : blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration} + { + last_block_reduce_length_ = + reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1); + numBlockTileIterationByLastBlock_ = + (last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + __device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const + { + if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ || + block_local_id < blkGroupSize_ - 1) + return (KThreadSliceSize * numBlockTileIteration_); + + index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize; + + if(count_in_last_tile == 0) + return (KThreadSliceSize * numBlockTileIterationByLastBlock_); + else + { + index_t num_complete_slice = count_in_last_tile / KThreadSliceSize; + + if(thread_k_cluster_id < num_complete_slice) + return (KThreadSliceSize * numBlockTileIterationByLastBlock_); + else if(thread_k_cluster_id == num_complete_slice) + return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) + + count_in_last_tile); + else + return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1)); + }; + }; + + index_t blkGroupSize_; + index_t numBlockTileIteration_; + + index_t last_block_reduce_length_; + index_t numBlockTileIterationByLastBlock_; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a4053b1f36221053bb2ef558dea469c64ce36e81 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 + type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 + x1; + }; +}; + +struct Subtract +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp - x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 - x1; + }; +}; + +struct Bilinear +{ + Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = alpha_ * x0 + beta_ * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = type_convert(alpha_) * x0 + type_convert(beta_) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); + }; + + float alpha_; + float beta_; +}; + +struct AddRelu +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = a > 0.0f ? a : 0.0f; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = a > 0.0 ? a : 0.0; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + const float a = x0 + x1; + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > 0.0f ? a : 0.0f; + }; + + template <> + __host__ __device__ constexpr void + operator()(int& y, const int& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > 0 ? a : 0; + }; + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > 0 ? a : 0; + }; +}; + +struct AddHardswish +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + double a = x0 + x1; + double b = a + 3.0; + double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + float a = x0 + x1; + float b = a + 3.0f; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; +}; + +// C = A * B +// E = FastGelu(C + D) +struct AddFastGelu +{ + // Fast GeLU + // https://paperswithcode.com/method/gelu + // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) + __host__ __device__ static constexpr float GetFastGeLU(float x) + { + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = exp(-u); + const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); + return x * cdf; + } + + template + static inline constexpr bool is_valid_param_type_v = + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const + { + static_assert(is_valid_param_type_v && is_valid_param_type_v && + is_valid_param_type_v); + + const float y = GetFastGeLU(type_convert(c) + type_convert(d)); + + e = type_convert(y); + } + + template + __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const + { + static_assert(is_valid_param_type_v); + + e = GetFastGeLU(c + type_convert(d)); + } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b66107a525569db2c8c706d7826221a686abf39e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/quantization_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// Need to ensure compiler will fail if there is no matching candidate, instead of compiler +// siliently do implicit type conversion +// +// Method 1: +// +// struct ExampleElementwiseOp +// { +// template +// __host__ __device__ constexpr void +// operator()(Y&, const X) const; +// +// template<> +// __host__ __device__ constexpr void +// operator()(half_t& y, const half_t& x) const +// { +// } +// }; +// +// Method 2: +// +// template +// struct ExampleElementwiseOp; +// +// template <> +// struct ExampleElementwiseOp +// { +// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const +// { +// } +// }; + +struct AddReluAdd +{ + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + half_t a = x0 + x1; + half_t b = a > 0 ? a : 0; + y = b + x2; + } + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + template <> + __host__ __device__ constexpr void operator()( + bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& y, const int8_t& x0, const int8_t& x1, const int8_t& x2) const + { + int32_t a = x0 + x1; + int32_t b = a > 0 ? a : 0; + int32_t c = b + x2; + y = c; + } + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + __host__ __device__ constexpr void operator()( + int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const + { + int32_t a = x0 + x1; + int32_t b = a > 0 ? a : 0; + int32_t c = b + x2; + y = c; + } +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +}; + +struct AddHardswishAdd +{ + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } +}; + +// C = A * B +// E = C + D0 + D1 +struct AddAdd +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const + { + // Only support floating so far + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + const C y = c + type_convert(d0) + type_convert(d1); + e = type_convert(y); + } +}; + +// C = A * B +// E = FastGelu(C + D0 + D1) +struct AddAddFastGelu +{ + // Fast GeLU + // https://paperswithcode.com/method/gelu + // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) + __host__ __device__ static constexpr float GetFastGeLU(float x) + { + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = exp(-u); + const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); + return x * cdf; + } + + template + static inline constexpr bool is_valid_param_type_v = + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || std::is_same_v +#endif + ; + + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const + { + static_assert(is_valid_param_type_v && is_valid_param_type_v && + is_valid_param_type_v && is_valid_param_type_v); + + const float y = + GetFastGeLU(type_convert(c) + type_convert(d0) + type_convert(d1)); + + e = type_convert(y); + } +}; + +struct Normalize +{ + // FIXME: is double absolutely necessary? + Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {} + + template + __host__ __device__ constexpr void operator()(T1& y, + const T1& x, + const T2& mean, + const T2& mean_square, + const T3& gamma, + const T3& beta) const; + + template <> + __host__ __device__ constexpr void operator()(half_t& y, + const half_t& x, + const float& mean, + const float& mean_square, + const half_t& gamma, + const half_t& beta) const + { + using ck::math::sqrt; + + float variance = mean_square - (mean * mean); + + float tmp_x = type_convert(x); + float tmp_gamma = type_convert(gamma); + float tmp_beta = type_convert(beta); + + float tmp_y = + ((tmp_x - mean) / sqrt(variance + type_convert(epsilon_))) * tmp_gamma + + tmp_beta; + + y = type_convert(tmp_y); + }; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x, + const float& mean, + const float& mean_square, + const float& gamma, + const float& beta) const + { + using ck::math::sqrt; + + float variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + type_convert(epsilon_))) * gamma + beta; + }; + + template <> + __host__ __device__ constexpr void operator()(double& y, + const double& x, + const double& mean, + const double& mean_square, + const double& gamma, + const double& beta) const + { + using ck::math::sqrt; + + double variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta; + }; + + // FIXME: is double absolutely necessary? + double epsilon_; +}; + +template +struct UnaryTypeConvert; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const + { + y = ck::type_convert(x); + } +}; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const + { + y = ck::type_convert(x); + } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/quantization_operation.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/quantization_operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f2c2f877314d07f90438bc59debf30c8aa84b8f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/quantization_operation.hpp @@ -0,0 +1,124 @@ +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +template +struct Activation_Mul_Clamp +{ + Activation_Mul_Clamp(float requantScale, Activation activationOp) + : requantScale_(requantScale), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const + { + float x_fp32 = ck::type_convert(x); + activationOp_(x_fp32, x_fp32); + float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const + { + // We might type_convert to int8 after lambda in someplace + float x_fp32 = ck::type_convert(x); + activationOp_(x_fp32, x_fp32); + y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f); + } + + float requantScale_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is piecewise linear function, such as +// relu, leaky relu ...etc +template +struct Activation_Mul2_Clamp +{ + Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {} + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const float& requantScale) const + { + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + Activation activationOp_; +}; + +// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +template +struct Add_Activation_Mul_Clamp +{ + Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) + : requantScale_(requantScale), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias) const + { + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float requantScale_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is piecewise linear function, such as +// relu, leaky relu ...etc +template +struct Add_Activation_Mul2_Clamp +{ + Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {} + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const + { + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + Activation activationOp_; +}; + +// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +template +struct Add_Mul_Activation_Mul_Clamp +{ + Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp) + : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias) const + { + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = requantScale1_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float requantScale1_; + float requantScale2_; + Activation activationOp_; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fbdfe9262fa89190b8ec8361bc3f44bf78ba3015 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/math_v2.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct PassThrough +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(double& y, const double& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(int32_t& y, const int32_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(int8_t& y, const int32_t& x) const + { + y = type_convert(x); + } + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + __host__ __device__ void operator()(int4_t& y, const int4_t& x) const + { + y = x; + } +#endif +}; + +struct UnaryConvert +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + y = type_convert(x); + } +}; + +struct Scale +{ + __host__ __device__ Scale(float scale) : scale_(scale) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = scale_ * x; + }; + + float scale_; +}; + +struct ScaleAndResetNaNToMinusInfinity +{ + __host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = ck::math::isnan(x) ? -ck::NumericLimits::Infinity() : scale_ * x; + }; + + float scale_; +}; + +struct UnaryDivide +{ + __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {} + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +struct UnarySquare +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || is_same_v +#endif + , + "Data type is not supported by this operation!"); + y = x * x; + }; +}; + +struct UnaryAbs +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::abs(x); + }; +}; + +struct UnarySqrt +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sqrt(x); + }; +}; + +struct Relu +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + y = x > 0 ? x : 0; + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float x_f32 = ck::type_convert(x); + float y_f32 = x_f32 > 0 ? x_f32 : 0; + y = ck::type_convert(y_f32); + } +}; + +// Y = FastGelu(X) +struct FastGelu +{ + // Fast GeLU + // https://paperswithcode.com/method/gelu + // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) + __host__ __device__ static constexpr float GetFastGeLU(float x) + { + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = exp(-u); + const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); + return x * cdf; + } + + template + static inline constexpr bool is_valid_param_type_v = + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || std::is_same_v +#endif + ; + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + static_assert(is_valid_param_type_v && is_valid_param_type_v); + + const float tmp_y = GetFastGeLU(type_convert(x)); + y = type_convert(tmp_y); + } +}; + +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+erf(x/sqrt(2))) +struct Gelu +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = 0.5f * x * (1.f + erf(float(0.70710678118f * x))); + } + + template <> + __host__ __device__ void operator()(ck::half_t& y, + const ck::half_t& x) const + { + y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x)))); + } +}; + +struct Sigmoid +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = 1 / (ck::type_convert(1) + exp(-x)); + }; + + int32_t divider_ = 1; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a72a4ee068f2bbcb1a915df0368c9b05caa298f9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp @@ -0,0 +1,498 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_second_half_batchnorm_backward_final( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const XYGridDesc_M_K dx_grid_desc_m_k, + const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + index_t blkgroup_size, + long_index_t reduce_size, + index_t num_xy_k_block_tile_iteration, + index_t num_dscale_dbias_k_block_tile_iteration, + const DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + const DscaleDbiasDataType* const __restrict__ p_reduce_dbias, + const MeanVarDataType* const __restrict__ p_mean, + const MeanVarDataType* const __restrict__ p_inv_var, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) +{ + GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k, + dy_grid_desc_m_k, + dx_grid_desc_m_k, + dscale_dbias_grid_desc_m_k, + mean_var_grid_desc_m, + scale_grid_desc_m, + bias_grid_desc_m, + blkgroup_size, + reduce_size, + num_xy_k_block_tile_iteration, + num_dscale_dbias_k_block_tile_iteration, + p_reduce_dscale, + p_reduce_dbias, + p_mean, + p_inv_var, + p_x, + p_dy, + p_scale, + dy_elementwise_op, + p_dx, + p_dscale, + p_dbias); +}; + +template +struct GridwiseReduceSecondHalfBatchNormBackwardFinal +{ + static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0 && + MThreadSliceSize % DxDstVectorSize == 0) || + (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0 && + KThreadSliceSize % DxDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_1 = decltype( + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // Two of the steps of Multiblock BatchNorm Backward + // Step 1: Second half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // Step 2: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly + // clang-format on + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& dy_grid_desc_m_k, + const XYGridDesc_M_K& dx_grid_desc_m_k, + const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m, + index_t blkgroup_size, + long_index_t reduce_size, + index_t num_xy_k_block_tile_iteration, + index_t num_dscale_dbias_k_block_tile_iteration, + const DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + const DscaleDbiasDataType* const __restrict__ p_reduce_dbias, + const MeanVarDataType* const __restrict__ p_mean, + const MeanVarDataType* const __restrict__ p_inv_var, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) + { + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + reduce_dscale_thread_buf; + StaticBuffer + reduce_dbias_thread_buf; + + StaticBuffer dscale_thread_buf; + StaticBuffer dbias_thread_buf; + + StaticBuffer + x_thread_buf; + StaticBuffer + dy_thread_buf; + StaticBuffer + dx_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer + inv_var_thread_buf; + StaticBuffer scale_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + // clang-format off + // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // clang-format on + + auto threadwise_dscale_dbias_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + dscale_dbias_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_dscale_dbias_store_m = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DscaleDbiasDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dscale_dbias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + const auto reduce_dscale_global_buf = make_dynamic_buffer( + p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize()); + + const auto reduce_dbias_global_buf = make_dynamic_buffer( + p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize()); + + auto dscale_global_buf = make_dynamic_buffer( + p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + auto dbias_global_buf = make_dynamic_buffer( + p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + constexpr auto dscale_dbias_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize * 1); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dscale_thread_buf(I) = type_convert(0.0f); + dbias_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration; + ++reducedTiles) + { + threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, + reduce_dscale_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dscale_thread_buf); + + threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, + reduce_dbias_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dbias_thread_buf); + + ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf); + ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf); + + threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k, + dscale_dbias_thread_copy_step_m_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I)); + block_sync_lds(); + BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I)); + }); + + threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m, + make_tuple(I0), + dscale_thread_buf, + dscale_dbias_grid_desc_m, + dscale_global_buf); + + threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m, + make_tuple(I0), + dbias_thread_buf, + dscale_dbias_grid_desc_m, + dbias_global_buf); + + // clang-format off + // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance) + // clang-format on + + const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + const auto x_global_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto dy_global_buf = make_dynamic_buffer( + p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_buf = make_dynamic_buffer( + p_dx, dx_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto mean_global_buf = make_dynamic_buffer( + p_mean, mean_var_grid_desc_m.GetElementSpaceSize()); + + const auto inv_var_global_buf = make_dynamic_buffer( + p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + inv_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf); + + constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); + + AccDataType inv_reduce_size = + type_convert(1.0) / type_convert(reduce_size); + + for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM]; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + AccDataType tmpVal = norm_x * dscale_thread_buf[iM]; + + dx_thread_buf(Number{}) = + multiplier * + (type_convert(reduce_size) * dy_thread_buf[Number{}] - + dbias_thread_buf[iM] - tmpVal); + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, xy_thread_copy_step_m_k); + } + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp new file mode 100644 index 0000000000000000000000000000000000000000..08cb0dd191303813a9283b19e38e2952fa0ccf27 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_multiblock_welford_first_half( + const XGridDesc_M_K x_grid_desc_m_k, + const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const p_welford_mean, + MeanVarDataType* const p_welford_variance, + int32_t* const p_welford_count) +{ + GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + p_x, + p_welford_mean, + p_welford_variance, + p_welford_count); +}; + +template +struct GridwiseMultiblockWelfordFirstHalf +{ + static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) || + (XSrcCountSrcVectorDim == 1 && + KThreadSliceSize % XSrcCountSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward. + // clang-format on + __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, + const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const p_welford_mean, + MeanVarDataType* const p_welford_variance, + int32_t* const p_welford_count) + { + StaticBuffer + x_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_welford_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto threadwise_welford_count_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + auto welford_mean_global_val_buf = make_dynamic_buffer( + p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto welford_var_global_val_buf = make_dynamic_buffer( + p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto welford_count_global_val_buf = make_dynamic_buffer( + p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = + get_reduce_count_per_thread(block_local_id, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + welford_count_thread_buf(I) = threadwise_welford.cur_count_; + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_mean_thread_buf, + mean_var_count_grid_desc_m_g, + welford_mean_global_val_buf); + + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_var_thread_buf, + mean_var_count_grid_desc_m_g, + welford_var_global_val_buf); + + threadwise_welford_count_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_count_thread_buf, + mean_var_count_grid_desc_m_g, + welford_count_global_val_buf); + }; + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp new file mode 100644 index 0000000000000000000000000000000000000000..548d7fd40ac70fe7fde5e5973a177600b0941c5c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp @@ -0,0 +1,571 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_welford_second_half_batchnorm_forward_final( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseWelfordSecondHalfBatchNormForwardFinal_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + mean_var_count_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + blkgroup_size, + num_xy_k_block_tile_iteration, + num_mean_var_count_k_block_tile_iteration, + epsilon, + p_in_welford_mean, + p_in_welford_variance, + p_in_welford_count, + p_x, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseWelfordSecondHalfBatchNormForwardFinal +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_1 = decltype( + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + + { + using ck::math::sqrt; + + StaticBuffer + in_welford_mean_thread_buf; + StaticBuffer + in_welford_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + StaticBuffer + x_thread_buf; + StaticBuffer + y_thread_buf; + + StaticBuffer scale_thread_buf; + StaticBuffer bias_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + const auto welford_mean_global_val_buf = make_dynamic_buffer( + p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_var_global_val_buf = make_dynamic_buffer( + p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + constexpr auto mean_var_count_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize * 1); + + // Step 1: do final welford reduction to get mean and variance + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration; + ++reducedTiles) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_welford_mean_thread_buf, + in_welford_var_thread_buf, + in_welford_count_thread_buf, + welford_mean_thread_buf, + welford_var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + // Step 2: do normalization and output y + + const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t workTiles = 0; workTiles < num_xy_k_block_tile_iteration; ++workTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[iM] / sqrt(welford_var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[iM] - welford_mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_thread_copy_step_m_k); + } + + // Step 3: update the moving average of mean and variance (optional) + + if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load_m = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load_m.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load_m.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + welford_mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + welford_var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 4: save mean and inv-variance (optional) + + if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + // calculate inv-variance as 1/sqrt(epsilon+variance) + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + welford_mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + welford_var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp new file mode 100644 index 0000000000000000000000000000000000000000..42b7e172b239e69ae0ca840dd80d8bb203a5480f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_welford_second_half_reduce_first_half( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const DyElementwiseOp dy_elementwise_op, + MeanVarDataType* const __restrict__ p_out_welford_mean, + MeanVarDataType* const __restrict__ p_out_welford_inv_variance, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + DscaleDbiasDataType* const __restrict__ p_reduce_dbias) +{ + GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k, + dy_grid_desc_m_k, + mean_var_grid_desc_m, + mean_var_count_grid_desc_m_k, + dscale_dbias_grid_desc_m_g, + blkgroup_size, + num_xy_k_block_tile_iteration, + num_mean_var_count_k_block_tile_iteration, + epsilon, + haveSavedMeanInvVar, + p_savedMean, + p_savedInvVar, + p_in_welford_mean, + p_in_welford_variance, + p_in_welford_count, + dy_elementwise_op, + p_out_welford_mean, + p_out_welford_inv_variance, + p_x, + p_dy, + p_reduce_dscale, + p_reduce_dbias); +}; + +template +struct GridwiseWelfordSecondHalfReduceFirstHalf +{ + static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0) || + (XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceSrcDesc_M_1 = decltype( + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // Two of the steps of Multiblock BatchNorm Backward + // Step 1: Second half of Welford method to calculate mean and variance, as well as getting inv-variance = 1/sqrt(epsilon+variance) + // Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // clang-format on + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& dy_grid_desc_m_k, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const DyElementwiseOp dy_elementwise_op, + MeanVarDataType* const __restrict__ p_out_welford_mean, + MeanVarDataType* const __restrict__ p_out_welford_inv_variance, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + DscaleDbiasDataType* const __restrict__ p_reduce_dbias) + { + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_welford_mean_thread_buf; + StaticBuffer + in_welford_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + StaticBuffer& mean_thread_buf = + welford_mean_thread_buf; + StaticBuffer& + inv_var_thread_buf = welford_var_thread_buf; + + StaticBuffer + x_thread_buf; + StaticBuffer + dy_thread_buf; + + // buffer of values of dy * (x-mean) * inv-variance, used as input of Blockwise reduction + StaticBuffer + tmp1_thread_buf; + + StaticBuffer + reduce_dscale_thread_buf; + StaticBuffer + reduce_dbias_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + // clang-format off + // Step 1: load existing mean and inv-variance, or do final welford reduction on mean and variance as well as get inv-variance = 1/sqrt(epsilon+variance) + // clang-format on + + if(haveSavedMeanInvVar) + { + const auto mean_global_buf = make_dynamic_buffer( + p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + const auto inv_var_global_buf = make_dynamic_buffer( + p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_inv_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + inv_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf); + } + else + { + const auto welford_mean_global_buf = make_dynamic_buffer( + p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_var_global_buf = make_dynamic_buffer( + p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_count_global_buf = make_dynamic_buffer( + p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + constexpr auto mean_var_count_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize * 1); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration; + ++reducedTiles) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_mean_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_var_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_count_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_welford_mean_thread_buf, + in_welford_var_thread_buf, + in_welford_count_thread_buf, + welford_mean_thread_buf, + welford_var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow( + mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run(welford_mean_thread_buf(I), + welford_var_thread_buf(I), + welford_count_thread_buf(I)); + }); + + // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_var_thread_buf(I) = + type_convert(1.0) / sqrt(welford_var_thread_buf[I] + epsilon); + }); + + if(block_local_id == 0 && thread_k_cluster_id == 0) + { + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto mean_global_buf = make_dynamic_buffer( + p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto inv_var_global_buf = make_dynamic_buffer( + p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize()); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf, + mean_var_grid_desc_m, + inv_var_global_buf); + }; + }; + + const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + const auto x_global_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto dy_global_buf = make_dynamic_buffer( + p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); + + constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + reduce_dscale_thread_buf(I) = type_convert(0); + reduce_dbias_thread_buf(I) = type_convert(0); + }); + + // clang-format off + // Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // clang-format on + + for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; + }); + }); + + ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf); + ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k); + }; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I)); + block_sync_lds(); + BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I)); + }); + + auto threadwise_dscale_dbias_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + dscale_dbias_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto reduce_dscale_global_buf = make_dynamic_buffer( + p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize()); + + auto reduce_dbias_global_buf = make_dynamic_buffer( + p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize()); + + if(thread_k_cluster_id == 0) + { + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dscale_thread_buf, + dscale_dbias_grid_desc_m_g, + reduce_dscale_global_buf); + + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dbias_thread_buf, + dscale_dbias_grid_desc_m_g, + reduce_dbias_global_buf); + }; + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp new file mode 100644 index 0000000000000000000000000000000000000000..460b684346b531db5381543372b01d6fe3e2f52e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -0,0 +1,546 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math.hpp" +#include "ck/utility/number.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +namespace ck { + +// Rows of column-vectors +template +struct BlockToCTileMap_M00_N0_M01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) + : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + + const index_t grid_size = M00 * M01_ * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + if(M0 % M01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + + const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), + make_unmerge_transform(make_tuple(M00, M01)), + make_pass_through_transform(make_tuple(N0))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + + const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_n0_m01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1)); + UnderlyingMap underlying_map_; +}; + +// Rows of column-vectors +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) + : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// 2D slices of column-vectors in 3D space +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_KSplit_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8, + index_t KSplit = 1) + : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0 * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + index_t KSplit_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// Blocks of row-vectors +template +struct BlockToCTileMap_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1) + : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1)); + UnderlyingMap underlying_map_; +}; + +// 2D slices of row-vectors in 3D space +template +struct BlockToCTileMap_KSplit_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1, + index_t KSplit = 1) + : c_grid_desc_m_n_(c_grid_desc_m_n), + M01_(M01), + N01_(N01), + KSplit_(KSplit), + underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) + { + } + + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == 1); + + return underlying_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] % CalculateGridSize())); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __device__ constexpr index_t CalculateGridSize() const + { + return CalculateGridSize(c_grid_desc_m_n_); + } + + __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + index_t KSplit) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KSplit), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; + } + + CGridDesc_M_N c_grid_desc_m_n_; + index_t M01_, N01_, KSplit_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); + UnderlyingMap underlying_map_; +}; + +template +__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) +{ + bool is_valid = false; + + const index_t m_block = c_tile_dim[Number<0>{}]; + const index_t n_block = c_tile_dim[Number<1>{}]; + + if constexpr(CTileIdx::Size() == 2) + { + const index_t m_block_idx = c_tile_idx[Number<0>{}]; + const index_t n_block_idx = c_tile_idx[Number<1>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + } + else if constexpr(CTileIdx::Size() == 3) + { + const index_t ksplit_idx = c_tile_idx[Number<0>{}]; + const index_t m_block_idx = c_tile_idx[Number<1>{}]; + const index_t n_block_idx = c_tile_idx[Number<2>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + ignore = ksplit_idx; + } + + return is_valid; +} + +// This wrapper class is for grouped gemm where it subtracts blockIdx by a value so that the +// workgroups assigned to a given gemm problem have top index offsetted to range [0, +// grid_size_per_gemm] +template +struct OffsettedBlockToCTileMap +{ + using underlying_type = UnderlyingBlockToCTileMap; + + OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bdebe3816f22d6a4f10e913210abc8c8347e20d0 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -0,0 +1,321 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M_Tuple out_grid_desc_m_tuple, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple, + index_t block_group_size, + index_t num_k_block_tile_iteration, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) +{ + GridwiseMultipleReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_tuple, + in_elementwise_op_tuple, + acc_elementwise_op_tuple, + block_group_size, + num_k_block_tile_iteration, + alpha_values, + p_in_value_global, + beta_values, + p_out_value_global_tuple); +}; + +template +struct GridwiseMultipleReduction_mk_to_m_multiblock +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypePointerTuple::Size() && + NumReduction == OutGridDesc_M_Tuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M_Tuple& out_grid_desc_m_tuple, + const InElementwiseOperationTuple& in_elementwise_op_tuple, + const AccElementwiseOperationTuple& acc_elementwise_op_tuple, + index_t block_group_size, + index_t num_k_block_tile_iteration, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) + { + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + // LDS, reused by all reductions + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf_tuple = generate_tuple( + [&](auto iR) { + return make_dynamic_buffer( + p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize()); + }, + Number{}); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + auto in_thread_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + auto accu_value_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}( + [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; }); + }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, NumReduction, 1>{}([&](auto iR) { + using OutDataTypePointer = remove_cvref_t; + using OutDataType = remove_cvref_t>; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf_tuple(iR)(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I), + accu_value_buf_tuple(iR)(I)); + + accu_value_buf_tuple(iR)(I) *= alpha_values[iR]; + } + }); + + if(thread_k_cluster_id == 0) + { + if(block_group_size == 0 && !float_equal_zero{}(beta_values[iR])) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + 1, + false>( + out_grid_desc_m_tuple[iR], + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR), + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf_tuple(iR)(I) += + type_convert(priorDstValueBuf[I]) * beta_values[iR]; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m_tuple[iR], + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf_tuple[iR], + out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR)); + }; + }); + }; +}; // namespace ck + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1313ec9435e782ad9cc9b92580464f6d5257b93e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M_Tuple out_grid_desc_m_tuple, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) +{ + GridwiseMultipleReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_tuple, + in_elementwise_op_tuple, + acc_elementwise_op_tuple, + alpha_values, + p_in_value_global, + beta_values, + p_out_value_global_tuple); +}; + +template +struct GridwiseMultipleReduction_mk_to_m_threadwise +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypePointerTuple::Size() && + NumReduction == OutGridDesc_M_Tuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M_Tuple& out_grid_desc_m_tuple, + const InElementwiseOperationTuple& in_elementwise_op_tuple, + const AccElementwiseOperationTuple& acc_elementwise_op_tuple, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) + { + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf_tuple = generate_tuple( + [&](auto iR) { + return make_dynamic_buffer( + p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize()); + }, + Number{}); + + StaticBuffer + in_thread_buf; + + auto in_thread_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + auto accu_value_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}( + [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; }); + }); + + const index_t thread_global_1d_id = get_thread_global_1d_id(); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, NumReduction, 1>{}([&](auto iR) { + using OutDataTypePointer = remove_cvref_t; + using OutDataType = remove_cvref_t>; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I), + accu_value_buf_tuple(iR)(I)); + + accu_value_buf_tuple(iR)(I) *= alpha_values[iR]; + }); + + if(!float_equal_zero{}(beta_values[iR])) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + 1, + false>( + out_grid_desc_m_tuple[iR], + make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR), + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf_tuple(iR)(I) += + type_convert(priorDstValueBuf[I]) * beta_values[iR]; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m_tuple[iR], + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf_tuple[iR], + out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR)); + }); + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6836a66047531d9eef09dbffc58188f7696a783a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -0,0 +1,613 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + (void)p_in_index_global; + (void)p_out_index_global; + + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_multiblock +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}( + [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(block_group_size == 0 && !float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + false>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + } + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using BlockwiseReduceWithIndex = + PartitionedBlockwiseReductionWithIndex, + ThreadClusterArrangeOrder, + ReduceOperation, + PropagateNan>; + + using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; + + (void)in_elementwise_op; + + // LDS + __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; + __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_val_buf = + make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); + auto reduce_work_idx_buf = + make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_1d_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = identityVal; + accu_index_buf(I) = 0; + }); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + index_t indexOffset = 0; + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + // initialize the indices for the per-thread to-reduce values + in_thread_idx_buf(Number{}) = + indexOffset + thread_k_cluster_id * KThreadSliceSize + iK(); + + // do element-wise pre-reduction operation + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexOffset += K_BlockTileSize; + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + }; + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + // for indiced operation, acc_elementwise_op shoud do nothing + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(!float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + threadwise_dst_idx_store.Run(reduced_data_desc, + make_tuple(I0), + accu_index_buf, + out_grid_desc_m, + out_global_idx_buf); + } + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6c5bd29f9b5ee70b995342f04d59b8ed042b1194 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -0,0 +1,474 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_threadwise +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + using ThreadwiseReduce = ThreadwiseReduction; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto dst_global_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + true>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + dst_global_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex; + + (void)acc_elementwise_op; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = identityVal; + accu_index_buf(I) = 0; + }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t indexStart = 0; + index_t reducedLength = 0; + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + } + else + { + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_thread_idx_buf(Number{}) = indexStart + iK(); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + }; + + // for indiced operation, acc_elementwise_op shoud do nothing + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + false>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf); + + threadwise_dst_idx_store.Run( + reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf); + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fccb127d0f342911ed6bb1f8396c9d9cd0eec921 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,931 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmGemm_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + // Gemm0 + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + // Gemm1 + static constexpr auto B1K0 = Number{}; + static constexpr auto B1K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_space_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + const auto a_block_reset_copy_step = + make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b_block_reset_copy_step = + make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // + // set up Gemm1 + // + + // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto AccN3 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = + make_tuple(Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + decltype(acc_element_op), + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{acc_element_op}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. + // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; + + auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + false, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); + + const index_t num_gemm1_k_block_outer_loop = + b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + + // Initialize C + c_thread_buf.Clear(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + // gemm0 + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + acc_thread_buf, + num_k_block_main_loop); + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for gemm0 LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf); + } + } // end gemm1 + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, + a_block_reset_copy_step); // rewind K + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, + b_block_reset_copy_step); // rewind K and step N + + block_sync_lds(); // wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b9f4a3080a0e17ca0ba169bc8b83c006acb2305d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -0,0 +1,1268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto WaveSize = 64; + // K1 should be Number<...> + // Gemm0 + static constexpr auto A0K1 = Number{}; + static constexpr auto B0K1 = Number{}; + + static constexpr auto A0K0PerBlock = Number{}; + static constexpr auto B0K0PerBlock = Number{}; + + static constexpr auto Gemm0MWaves = Gemm0MPerBlock / (Gemm0MPerXdl * Gemm0MXdlPerWave); + static constexpr auto Gemm0NWaves = Gemm0NPerBlock / (Gemm0NPerXdl * Gemm0NXdlPerWave); + // Gemm1 + static constexpr auto B1K1 = Number{}; + static constexpr auto B1K0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + // ck::Tuple + static constexpr auto MakeD0sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeD1sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D1DataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __device__ static auto GetGemm0WaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id) + { + constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(WaveSize / Gemm0NPerXdl, Gemm0NPerXdl))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + A0BlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = Gemm0NPerBlock / (Gemm0NXdlPerWave * Gemm0NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + A0BlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A0 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(A0K0PerBlock, Number{}, A0K1), + make_tuple(Number{} * A0K1, A0K1, I1)); + } + + __host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B0 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B0K0PerBlock, Number{}, B0K1), + make_tuple(Number{} * B0K1, B0K1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0PerBlock, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl); + + constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a0_block_space_size_aligned + + SharedMemTrait::b0_block_space_size_aligned) * + sizeof(A0B0B1DataType); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(A0B0B1DataType); + const index_t c1_block_bytes_end = + SharedMemTrait::c1_block_space_size * sizeof(C1ShuffleDataType); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, c1_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k, + const B0GridDesc_N_K& b0_grid_desc_n_k, + const B1GridDesc_N_K& b1_grid_desc_n_k, + const E1GridDesc_M_N& e1_grid_desc_m_n, + const Block2E1TileMap& block_2_e1tile_map) + { + static_assert((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) == 0) && + (Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a0_grid_desc_m_k.GetLength(I0); + const auto N = b0_grid_desc_n_k.GetLength(I0); + const auto K = a0_grid_desc_m_k.GetLength(I1); + const auto Gemm1N = b1_grid_desc_n_k.GetLength(I0); + + if(!(M == e1_grid_desc_m_n.GetLength(I0) && Gemm1N == e1_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % Gemm0MPerBlock == 0 && N % Gemm0NPerBlock == 0 && K % Gemm0KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / Gemm0KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(Gemm0NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / Gemm0KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + // A0 desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultA0GridDescriptor_AK0_M_AK1(const A0GridDesc_M_K& a0_grid_desc_m_k) + { + const auto M = a0_grid_desc_m_k.GetLength(I0); + const auto K = a0_grid_desc_m_k.GetLength(I1); + + const auto A0K0 = K / A0K1; + + return transform_tensor_descriptor( + a0_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(A0K0, A0K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B0 desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultB0GridDescriptor_BK0_N_BK1(const B0GridDesc_N_K& b0_grid_desc_n_k) + { + const auto N = b0_grid_desc_n_k.GetLength(I0); + const auto K = b0_grid_desc_n_k.GetLength(I1); + + const auto B0K0 = K / B0K1; + + return transform_tensor_descriptor( + b0_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B0K0, B0K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // D0 desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n) + { + const auto M = d0_grid_desc_m_n.GetLength(I0); + const auto N = d0_grid_desc_m_n.GetLength(I1); + + constexpr auto mfma = + MfmaSelector::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N5 = mfma.group_size; + return transform_tensor_descriptor( + d0_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + M / Gemm0MPerBlock, Gemm0MXdlPerWave, Gemm0MWaves, Gemm0MPerXdl)), + make_unmerge_transform(make_tuple(N / Gemm0NPerBlock, + Gemm0NXdlPerWave, + Gemm0NWaves, + N3, + WaveSize / Gemm0NPerXdl, + N5))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); + } + + // B1 desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k) + { + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // C1 desc for destination in blockwise copy + __host__ __device__ static constexpr auto + MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc_M_N& e1_grid_desc_m_n) + { + const auto M = e1_grid_desc_m_n.GetLength(I0); + const auto N = e1_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / Gemm0MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e1_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e1_grid_desc_mblock_mperblock_nblock_nperblock; + } + // D0s desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]); + }, + Number{}); + } + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDescriptor_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to C1 matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2E1TileMap(const E1GridDesc_M_N& e1_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e1_grid_desc_m_n); + } + + using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t; + + using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2E1TileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A0 and B0: be careful of alignment + static constexpr auto a0_block_desc_ak0_m_ak1 = + GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b0_block_desc_bk0_n_bk1 = + GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(A0K1, B0K1), B1K1); + + static constexpr auto a0_block_space_size_aligned = math::integer_least_multiple( + a0_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple( + b0_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a0_block_space_offset = 0; + static constexpr auto b0_block_space_offset = a0_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for C1 shuffle in LDS + static constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c1_block_space_size = + c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + using D0sGridPointer = decltype(MakeD0sGridPointer()); + using D1sGridPointer = decltype(MakeD1sGridPointer()); + + template + __device__ static void Run(const A0B0B1DataType* __restrict__ p_a0_grid, + const A0B0B1DataType* __restrict__ p_b0_grid, + D0sGridPointer p_d0s_grid, + const A0B0B1DataType* __restrict__ p_b1_grid, + D1sGridPointer p_d1s_grid, + E1DataType* __restrict__ p_e1_grid, + void* __restrict__ p_shared, + const A0ElementwiseOperation& a0_element_op, + const B0ElementwiseOperation& b0_element_op, + const CDE0ElementwiseOperation& cde0_element_op, + const B1ElementwiseOperation& b1_element_op, + const CDE1ElementwiseOperation& cde1_element_op, + const A0GridDesc_AK0_M_AK1& a0_grid_desc_ak0_m_ak1, + const B0GridDesc_BK0_N_BK1& b0_grid_desc_bk0_n_bk1, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e1_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2E1TileMap& block_2_e1tile_map) + { + const auto a0_grid_buf = make_dynamic_buffer( + p_a0_grid, a0_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b0_grid_buf = make_dynamic_buffer( + p_b0_grid, b0_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto e1_grid_buf = make_dynamic_buffer( + p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + const auto d0s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d0s_grid[i], + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize()); + }, + Number{}); + const auto d1s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d1s_grid[i], + d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_e1tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_e1tile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * Gemm0MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A0 matrix in LDS memory, dst of blockwise copy + constexpr auto a0_block_desc_ak0_m_ak1 = GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B0 matrix in LDS memory, dst of blockwise copy + constexpr auto b0_block_desc_bk0_n_bk1 = GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A0 matrix blockwise copy + auto a0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + A0BlockTransferThreadClusterLengths_AK0_M_AK1, + A0BlockTransferThreadClusterArrangeOrder, + A0B0B1DataType, + A0B0B1DataType, + decltype(a0_grid_desc_ak0_m_ak1), + decltype(a0_block_desc_ak0_m_ak1), + A0BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + A0BlockTransferSrcVectorDim, + 2, + A0BlockTransferSrcScalarPerVector, + A0BlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemm0KPrefetchStage>( + a0_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a0_element_op, + a0_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B0 matrix blockwise copy + auto b0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B0BlockTransferThreadClusterLengths_BK0_N_BK1, + B0BlockTransferThreadClusterArrangeOrder, + A0B0B1DataType, + A0B0B1DataType, + decltype(b0_grid_desc_bk0_n_bk1), + decltype(b0_block_desc_bk0_n_bk1), + B0BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B0BlockTransferSrcVectorDim, + 2, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemm0KPrefetchStage>( + b0_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b0_element_op, + b0_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr index_t KPack = math::max( + math::lcm(A0K1, B0K1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< + BlockSize, + A0B0B1DataType, + Acc0DataType, + decltype(a0_block_desc_ak0_m_ak1), + decltype(b0_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a0_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b0_block_desc_bk0_n_bk1)), + Gemm0MPerBlock, + Gemm0NPerBlock, + Gemm0KPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm0NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer(); + + // LDS allocation for A0 and B0: be careful of alignment + auto a0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a0_block_space_offset, + a0_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b0_block_space_offset, + b0_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / A0K1, 0, 0); + constexpr auto b0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / B0K1, 0, 0); + const auto a0_block_reset_copy_step = + make_multi_index(-a0_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b0_block_reset_copy_step = + make_multi_index(-b0_grid_desc_bk0_n_bk1.GetLength(I0), Gemm0NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm0_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a0_grid_desc_ak0_m_ak1.GetLength(I0) * a0_grid_desc_ak0_m_ak1.GetLength(I2)) / + Gemm0KPerBlock); + + // + // set up Gemm1 + // + + // Acc0 matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm0.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // d0 matrix threadwise copy + constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId + I1, // NBlockID + I1, // MRepeat + I1, // NRepeat + I1, // MWaveId + I1, // NWaveId + I1, // MPerXdl + I1, // NGroupNum + I1, // NInputNum + n4)); // registerNum + + auto d0s_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer< + AddressSpaceEnum::Vgpr, + A0B0B1DataType, + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), + true>{}; + }, + Number{}); + + const auto wave_id = GetGemm0WaveIdx(); + const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 + + constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, n2, n4)); + + auto d0s_threadwise_copy = generate_tuple( + [&](auto i) { + return ThreadwiseTensorSliceTransfer_v2< + A0B0B1DataType, + A0B0B1DataType, + decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), + decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + n4, + 1, + false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(block_work_idx[I0], // MBlockId + 0, // NBlockId + 0, // mrepeat + 0, // nrepeat + wave_id[I0], // MWaveId + wave_id[I1], // NWaveId + wave_m_n_id[I1], // MPerXdl + 0, // group + wave_m_n_id[I0], // NInputIndex + 0)); // register number + }, + Number{}); + // acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc0_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc0_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto Acc0N3 = + blockwise_gemm0.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple( + Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + Acc0DataType, + A0B0B1DataType, + decltype(acc0_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{tensor_operation::element_wise::PassThrough{}}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + A0B0B1DataType, + A0B0B1DataType, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + 1>(b1_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b0_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr index_t Gemm1KPack = math::max( + math::lcm( + MfmaSelector::selected_mfma.group_size, + B1K1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< + BlockSize, + A0B0B1DataType, + Acc1DataType, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + Gemm0MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + false, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * XdlopsGemm{} + .K0PerXdlops>{ // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); + + const index_t num_gemm1_k_block_outer_loop = + b0_grid_desc_bk0_n_bk1.GetLength(I1) / Gemm0NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock; + + // Initialize C1 + c1_thread_buf.Clear(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + // gemm0 + gridwise_gemm0_pipeline.template Run(a0_grid_desc_ak0_m_ak1, + a0_block_desc_ak0_m_ak1, + a0_blockwise_copy, + a0_grid_buf, + a0_block_buf, + a0_block_slice_copy_step, + b0_grid_desc_bk0_n_bk1, + b0_block_desc_bk0_n_bk1, + b0_blockwise_copy, + b0_grid_buf, + b0_block_buf, + b0_block_slice_copy_step, + blockwise_gemm0, + acc0_thread_buf, + num_k_block_main_loop); + // bias+gelu + { + static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) { + static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) { + static_for<0, n2, 1>{}([&](auto groupid) { + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + + static_for<0, n4, 1>{}([&](auto i) { + constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( + make_tuple(mr, nr, groupid, i)); + + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + return d0s_thread_buf[iSrc][i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { + return acc0_thread_buf(Number{}); + }, + Number<2>{}); + + unpack2(cde0_element_op, dst_data_refs, src_data_refs); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0)); + }); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0)); + }); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0)); + }); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); + }); + } + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for gemm0 LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc0_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc0_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf); + } + } // end gemm1 + + a0_blockwise_copy.MoveSrcSliceWindow(a0_grid_desc_ak0_m_ak1, + a0_block_reset_copy_step); // rewind K + b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_bk0_n_bk1, + b0_block_reset_copy_step); // rewind K and step N + + block_sync_lds(); // wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C1 and write out + { + static_assert(Gemm0MXdlPerWave % C1ShuffleGemm0MXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % C1ShuffleGemm0NXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm1.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm1.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c1_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (Gemm0MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = Gemm0MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (Gemm0NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = Gemm0NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM C1 matrix starting index + const auto c1_thread_mtx_on_block = + blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c1_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c1_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c1_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c1_d1s_desc_refs = concat_tuple_of_reference( + tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c1_d1s_buf_refs = concat_tuple_of_reference( + tie(c1_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return d1s_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c1_d1s_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // shuffle: blockwise copy C from LDS to global + auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(C1ShuffleDataType{}), D1sDataType{})), + Tuple, + decltype(c1_d1s_desc_refs), + decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)), + CDE1ElementwiseOperation, + Sequence(E1GlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray + // type + Sequence<1, + C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl, + 1, + C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * + Gemm0NPerXdl>, // BlockSliceLengths, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c1_d1s_desc_refs, + idx_c1_d1s_block_begin, + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde1_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c1_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_e1_global = SpaceFillingCurve< + Sequence<1, Gemm0MPerBlock, 1, Gemm1NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl, + 1, + C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>>{}; + + constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c1_vgpr.GetIndexTupleOfNumber(access_id), + c1_thread_buf, + c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c1_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde1_shuffle_block_copy_lds_to_global.Run( + c1_d1s_desc_refs, + c1_d1s_buf_refs, + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e1_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id); + + // move on D1s + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + c1_d1s_desc_refs, i + I1, e1_global_step); + }); + + // move on C + cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7f4001a59ab497ec4120aad9bf1f2e07c91156f9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,1131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + // Gemm0 + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); + static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); + + // Gemm1 + static constexpr auto B1K0 = Number{}; + static constexpr auto B1K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned) * + sizeof(FloatGemmAcc); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_space_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map, + const C0MatrixMask& c0_matrix_mask) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t gemm1_n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + const auto a_block_reset_copy_step = + make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b_block_reset_copy_step = + make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // + // set up Gemm1 + // + + // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto AccN3 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = + make_tuple(Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{tensor_operation::element_wise::PassThrough{}}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc_bk0_n_bk1, + make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. + // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; + + auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + true, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); + + // + // Blockwise softmax + // + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + + // get acc0 8D thread cluster + constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() / + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0); + constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1); + constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2); + constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3); + constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4); + constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5); + constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6); + constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7); + + // get acc0 thread map + constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_m_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4)); + constexpr auto thread_slice_desc_m_n = + make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + const index_t num_gemm1_k_block_outer_loop = + b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + + // Initialize C + StaticBuffer + c_thread_buf; + c_thread_buf.Clear(); + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + auto n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock)) + { + continue; + } + // gemm0 + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + acc_thread_buf, + num_k_block_main_loop); + + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 8d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + // 8d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + constexpr auto M0 = c_block_lengths[I0]; + constexpr auto N0 = c_block_lengths[I1]; + constexpr auto M1 = c_block_lengths[I2]; + constexpr auto N1 = c_block_lengths[I3]; + constexpr auto M2 = c_block_lengths[I4]; + constexpr auto N2 = c_block_lengths[I5]; + constexpr auto N3 = c_block_lengths[I6]; + constexpr auto N4 = c_block_lengths[I7]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D( + Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)), + make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto n_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto n_global = n_local + n_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) + { + acc_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + else + { + acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); + } + }); + } + else + { + static_for<0, acc_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); + } + + block_sync_lds(); // wait for lds read in gemm0 blockwise gemm + + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + // workaround compiler issue; see ck/ck.hpp + if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 && + is_same_v && MPerBlock == 256 && NPerBlock == 128 && + Gemm1NPerBlock == 128) + { + __builtin_amdgcn_s_barrier(); + } + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, + a_block_reset_copy_step); // rewind K + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, + b_block_reset_copy_step); // rewind K and step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); // wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ede6a96dc9f53045b5fc029f4c4f4f6aa4fa4817 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_batchnorm_backward_with_blockwise_welford( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const XYGridDesc_M_K dx_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + long_index_t reduce_size, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) +{ + GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, + dy_grid_desc_m_k, + dx_grid_desc_m_k, + scale_grid_desc_m, + dscale_dbias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + reduce_size, + num_k_block_tile_iteration, + epsilon, + p_x, + p_dy, + p_scale, + haveSavedMeanInvVar, + p_savedMean, + p_savedInvVar, + dy_elementwise_op, + p_dx, + p_dscale, + p_dbias); +}; + +template +struct GridwiseBatchNormBackwardWithBlockwiseWelford +{ + static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0 && + MThreadSliceSize % DxDstVectorSize == 0) || + (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0 && + KThreadSliceSize % DxDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // Blockwise BatchNorm Backward + // Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size + // Output: dx, dscale, dbias + // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance) + // Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance) + // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly + // clang-format on + __device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const XYGridDesc_M_K dx_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + long_index_t reduce_size, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) + { + using ck::math::sqrt; + + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + x_thread_buf; + + StaticBuffer + dy_thread_buf; + + StaticBuffer + dx_thread_buf; + + // buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction + StaticBuffer + tmp1_thread_buf; + + StaticBuffer scale_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + StaticBuffer& + inv_var_thread_buf = var_thread_buf; + + StaticBuffer dscale_thread_buf; + StaticBuffer dbias_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_dscale_dbias_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DscaleDbiasDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dscale_dbias_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + + const auto x_global_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto dy_global_buf = make_dynamic_buffer( + p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_buf = make_dynamic_buffer( + p_dx, dx_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + auto dscale_global_buf = make_dynamic_buffer( + p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + auto dbias_global_buf = make_dynamic_buffer( + p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + // clang-format off + // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance) + // clang-format on + + if(haveSavedMeanInvVar) + { + const auto mean_global_buf = make_dynamic_buffer( + p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + const auto inv_var_global_buf = make_dynamic_buffer( + p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_inv_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + inv_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf); + } + else + { + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + // calculate inv-variance as 1/sqrt(epsilon+variance) + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + inv_var_thread_buf(I) = + type_convert(1.0) / sqrt(var_thread_buf[I] + epsilon); + }); + + threadwise_x_load.SetSrcSliceOrigin( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + }; + + // clang-format off + // Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance) + // clang-format on + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dscale_thread_buf(I) = type_convert(0); + dbias_thread_buf(I) = type_convert(0); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dx_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; + }); + }); + + ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf); + ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + }; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I)); + block_sync_lds(); + BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dscale_thread_buf, + dscale_dbias_grid_desc_m, + dscale_global_buf); + + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dbias_thread_buf, + dscale_dbias_grid_desc_m, + dbias_global_buf); + }; + + // clang-format off + // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly + // clang-format on + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k); + + AccDataType inv_reduce_size = + type_convert(1.0) / type_convert(reduce_size); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM]; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + AccDataType tmpVal = norm_x * dscale_thread_buf[iM]; + + dx_thread_buf(Number{}) = + multiplier * + (type_convert(reduce_size) * dy_thread_buf[Number{}] - + dbias_thread_buf[iM] - tmpVal); + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp new file mode 100644 index 0000000000000000000000000000000000000000..33c45a0f037b4c81fbaf810cc10e55f71a2ab254 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp @@ -0,0 +1,483 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_batchnorm_forward_with_blockwise_welford( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseBatchrNormForwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + epsilon, + p_x, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseBatchNormForwardWithBlockwiseWelford +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + { + using ck::math::sqrt; + + StaticBuffer + x_thread_buf; + + StaticBuffer scale_thread_buf; + + StaticBuffer bias_thread_buf; + + StaticBuffer + y_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + // Step 1: do welford reduction to get mean and variance + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + // Step 2: do normalization and output y + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[Number{}] / sqrt(var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[Number{}] - mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // normalize + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + + // Step 3: update the moving average of mean and variance (optional) + + if(updateMovingAverage && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 4: save mean and inv-variance (optional) + + if(saveMeanInvVariance && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2369f51795d9a3bcf5a28ad9de702d75e24fac48 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp @@ -0,0 +1,662 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP +#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r3.hpp" +#include "blockwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_dlops_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10) +{ + constexpr index_t shared_block_size = + GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // GM0 and GN0 need to known at compile-time + static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0); + static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2); + static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! GM0 and GN0 need to be known at compile-time"); + + const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return ( + (GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) && + GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) && + GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) && + GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) && + GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) && + GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) && + GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) && + GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) && + GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) && + GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) && + (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0)); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr index_t GM11 = GM1PerBlockGM11; + constexpr index_t GN11 = GN1PerBlockGN11; + + const index_t GM10 = GM1 / GM11; + const index_t GN10 = GN1 / GN11; + + const index_t grid_size = GM10 * GN10; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0) + { + const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0) + { + const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1) + { + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); + + const auto GM11 = Number{}; + const auto GM10 = GM1 / GM11; + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor( + a_grid_desc_gk0_gm0_gm1_gk1, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return a_grid_desc_gk0_gm0_gm10_gm11_gk1; + } + + __host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1) + { + const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); + + const auto GN11 = Number{}; + const auto GN10 = GN1 / GN11; + + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor( + b_grid_desc_gk0_gn0_gn1_gk1, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return b_grid_desc_gk0_gn0_gn10_gn11_gk1; + } + + __host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + constexpr auto BM = GM0 * GM11; + constexpr auto BN = GN0 * GN11; + + constexpr auto BM1 = + Number{}; + constexpr auto BN1 = + Number{}; + + constexpr auto BM0 = BM / BM1; + constexpr auto BN0 = BN / BN1; + + const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor( + c_grid_desc_gm0_gm1_gn0_gn1, + make_tuple(make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor( + c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_merge_transform(make_tuple(GM0, GM11)), + make_pass_through_transform(GN10), + make_merge_transform(make_tuple(GN0, GN11))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor( + c_gm10_bm_gn10_bn_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_unmerge_transform(make_tuple(BM0, BM1)), + make_pass_through_transform(GN10), + make_unmerge_transform(make_tuple(BN0, BN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; + } + + __host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(GM10, GN10))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_grid_block_cluster_blockid_to_gm10_gn10; + } + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = + decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{})); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = + decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{})); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = + decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{})); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); + + const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); + + // divide block work by [GM10, GN10] + const auto c_gm10_gn10_block_cluster_idx = + c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); + const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); + + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // A matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); + + // B matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); + + static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == + a_block_desc_gk0_bm_gk1.GetElementSpaceSize() && + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() == + b_block_desc_gk0_bn_gk1.GetElementSpaceSize(), + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1), + decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + make_multi_index(0, 0, igm10, 0, 0), + a_block_desc_gk0_gm0_gm10_gm11_gk1, + make_multi_index(0, 0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1), + decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + make_multi_index(0, 0, ign10, 0, 0), + b_block_desc_gk0_gn0_gn10_gn11_gk1, + make_multi_index(0, 0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS + // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS + // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_block_desc_gk0_bm_gk1), + decltype(b_block_desc_gk0_bn_gk1), + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + BM1PerThreadBM11, + BN1PerThreadBN11>{}; + + constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); + + ThreadwiseTensorSliceSet_v1{} + .Run(c_thread_desc_bm0_bm1_bn0_bn1, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t gk0_block_on_grid = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); + + gk0_block_on_grid += 2 * GK0PerBlock; + } while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), + decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1), + Sequence<1, + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1], + 1, + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + make_multi_index(igm10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1], + ign10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])} + .Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_buf, + CGridStepHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b82b65540d1c5612002549f72708be07b649dae --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op) +{ + GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple, + out_grid_1d_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + elementwise_op); +} + +template +struct GridwiseElementwise_1D +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGrid1dDescTuple::Size() && + NumOutput == OutGrid1dDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op) + { + const index_t thread_global_id = get_thread_global_1d_id(); + + auto in_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto out_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return StaticBuffer{}; + }, + Number{}); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread); + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InScalarPerVectorSeq::At( + I), // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc_tuple[I], + thread_global_offset}; + }, + Number{}); + + auto out_global_store_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + OutScalarPerVectorSeq::At(I), + InMemoryDataOperationEnum::Set, + 1, + false>( + out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{}); + }, + Number{}); + + index_t num_iter = M / (loop_step); + do + { + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf_tuple(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I], + loop_step_index); + }); + + static_for<0, MPerThread, 1>{}([&](auto iM) { + // get reference to in data + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + // get reference to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(elementwise_op, out_data_refs, in_data_refs); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).Run(thread_buffer_desc_m, + make_tuple(I0), + out_thread_buf_tuple[I], + out_grid_1d_desc_tuple[I], + out_global_buf_tuple(I)); + + out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I], + loop_step_index); + }); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..05257d16275e5f9817ec5d9cbbd06b133f0b5e05 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: MIT +// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple, + const OutGrid2dDescTuple out_grid_2d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const index_t num_threads_m, + const index_t num_threads_n) +{ + GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple, + out_grid_2d_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + elementwise_op, + num_threads_m, + num_threads_n); +} + +template +struct GridwiseElementwise_2D +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGrid2dDescTuple::Size() && + NumOutput == OutGrid2dDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr auto thread_buffer_desc_mn = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple, + const OutGrid2dDescTuple out_grid_2d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const index_t num_threads_m, + const index_t num_threads_n) + { + auto in_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto out_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return StaticBuffer{}; + }, + Number{}); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0); + const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1); + + const index_t loop_step_m = num_threads_m * MPerThread; + const index_t loop_step_n = num_threads_n * NPerThread; + + const index_t thread_1d_id = get_thread_global_1d_id(); + index_t tid_m = thread_1d_id / num_threads_n; + index_t tid_n = thread_1d_id % num_threads_n; + + const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2< + DataType, + DataType, + decltype(in_grid_2d_desc_tuple[I]), + decltype(thread_buffer_desc_mn), + Sequence, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 0, // SrcVectorDim + InScalarPerVectorSeq::At(I), // ScalarPerVector + 1, // SrcScalarStrideInVector + true>{in_grid_2d_desc_tuple[I], thread_global_offset}; + }, + Number{}); + + auto out_global_store_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return ThreadwiseTensorSliceTransfer_v1r3< + DataType, + DataType, + decltype(thread_buffer_desc_mn), + decltype(out_grid_2d_desc_tuple[I]), + PassThroughOp, + Sequence, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + 1, // OutScalarPerVectorSeq::At(I), + InMemoryDataOperationEnum::Set, + 1, + true>(out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{}); + }, + Number{}); + + index_t num_iter_m = M / (loop_step_m); + do + { + index_t num_iter_n = N / (loop_step_n); + do + { + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_mn, + make_tuple(I0, I0), + in_thread_buf_tuple(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I], + make_multi_index(0, loop_step_n)); + }); + + static_for<0, MPerThread, 1>{}([&](auto iM) { + static_for<0, NPerThread, 1>{}([&](auto iN) { + constexpr auto offset = + thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN)); + // get reference to in data + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { + return in_thread_buf_tuple(I)(Number{}); + }, + Number{}); + + // get referenec to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { + return out_thread_buf_tuple(I)(Number{}); + }, + Number{}); + unpack2(elementwise_op, out_data_refs, in_data_refs); + }); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).Run(thread_buffer_desc_mn, + make_tuple(I0, I0), + out_thread_buf_tuple[I], + out_grid_2d_desc_tuple[I], + out_global_buf_tuple(I)); + + out_global_store_tuple(I).MoveDstSliceWindow(out_grid_2d_desc_tuple[I], + make_multi_index(0, loop_step_n)); + }); + + } while(--num_iter_n); + + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).MoveSrcSliceWindow( + in_grid_2d_desc_tuple[I], + make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n)); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).MoveDstSliceWindow( + out_grid_2d_desc_tuple[I], + make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n)); + }); + } while(--num_iter_m); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..989f7f1c62d4977cdf968f5f08f3645645da7fbc --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -0,0 +1,500 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// X = Elementwise(input1, input2, input3, ...) +// Y = Normalization(X, beta, gamma) +template +struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto XThreadBufferNumber = Number{}; + static constexpr auto GammaThreadBufferNumber = Number{}; + static constexpr auto BetaThreadBufferNumber = Number{}; + static constexpr auto YThreadBufferNumber = Number{}; + + __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, + int thread_k_cluster_id) + { + int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + int kPerThread = + kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); + int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + + __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& beta_grid_desc_m_k, + const GridDesc_M_K& y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const InDataTypePointerTuple p_in_global_tuple, + XDataType* const __restrict__ p_x_lds, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const XElementwiseOperation x_elementwise_op, + const YElementwiseOperation y_elementwise_op) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t grid_size = get_grid_size(); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() == + 2); // matrix dimension + + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto x_lds_val_buf = make_dynamic_buffer( + p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); + + auto in_thread_buf_tuple = generate_tuple( + [&](auto) { + return generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + }, + Number{}); + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto beta_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto y_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2{ + in_grid_2d_desc_tuple[I], + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)}; + }, + Number{}); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * BetaSrcVectorSize)); + + using PassThrough = tensor_operation::element_wise::PassThrough; + PassThrough pass_through_op; + auto threadwise_x_store = + ThreadwiseTensorSliceTransfer_v1r3( + x_grid_desc_m_k, + make_multi_index(thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize), + pass_through_op); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + // Copy x from Cache + // one pass: fwd, second pass: bwd + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, NumInput, 1>{}([&](auto I) { // input load loop + in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_m_k, + make_tuple(I0, I0), + in_thread_buf_tuple(iK0)(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I], + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // get reference to in data + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { + return in_thread_buf_tuple(iK0)(I)(Number{}); + }, + Number{}); + + // get reference to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return x_thread_buf(iK0)(Number{}); }, + I1); + + unpack2(x_elementwise_op, out_data_refs, in_data_refs); + }); + }); + threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf); + + if constexpr(!SweepOnce) + { + threadwise_x_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(iK0), + x_grid_desc_m_k, + x_lds_val_buf); + threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; + + if constexpr(!SweepOnce) + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + if constexpr(!SweepOnce) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_lds_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + } + + static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon); + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + if constexpr(!SweepOnce) + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..16ba23280ddcdf9c3f1d87cb2e4cd64a72691650 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,997 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_bias_grid, + p_d0_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_bias_grid; + ignore = p_d0_grid; + ignore = p_reduces_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c1_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const C1ElementwiseOperation& c1_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + ReduceGlobalMemoryDataOperation::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + // c0 and c1 + constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock; + + auto c01_thread_buf = make_static_buffer( + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatReduceAcc, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC1, + FloatReduceAcc, + decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatC, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + block_sync_lds(); + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_buf, + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + c1_thread_copy_global_to_vgpr.Run( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_buf, + c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_functior(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c1_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = + ReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C1 + c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } // Reduction + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..25bb714a797e8ce2fb7ab83577251df5a731ff7a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -0,0 +1,685 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemmDlMultipleD_km_kn_mn +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + bool ret=(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); + if(!ret){ + std::cout<<"M="<, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_k0_m0_m1_k1, + make_multi_index(0, im0, 0, 0), + a_block_desc_k0_m0_m1_k1, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_k0_n0_n1_k1, + make_multi_index(0, in0, 0, 0), + b_block_desc_k0_n0_n1_k1, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + + // Initialize C + c_thread_buf.Clear(); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + + block_sync_lds(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize()); + }, + Number{}); + + auto ds_thread_buf = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto ds_threadwise_copy = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return ThreadwiseTensorSliceTransfer_v2< + DDataType, + DDataType, + decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]), + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + Sequence{}>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + 1, + false>(ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])); + }, + Number{}); + + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) { + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) { + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) { + // load d matrix data + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + ds_grid_buf[i], + c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + ds_thread_buf(i)); + }); + // cal element op + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}( + [&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + return ds_thread_buf[iSrc][i]; + }, + Number{}); + + // get reference to dst data + constexpr index_t c_offset = + c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset( + make_tuple(0, m10, m11, 0, n10, i)); + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return c_thread_buf(Number{}); }, + Number<2>{}); + + unpack2(cde_element_op, dst_data_refs, src_data_refs); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).MoveSrcSliceWindow( + ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index(0, 0, 0, 0, 1, 0)); + }); + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).MoveSrcSliceWindow( + ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index( + 0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0)); + }); + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).MoveSrcSliceWindow( + ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index( + 0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0)); + }); + }); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c839cde0f70aa7c23cf7fe80e688294cbc11a445 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -0,0 +1,577 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDl_km_kn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) + { + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_grid_desc_k0_m0_m1_k1; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) + { + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_grid_desc_k0_n0_n1_k1 = + transform_tensor_descriptor(b_grid_desc_k0_n_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_grid_desc_k0_n0_n1_k1; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_m0_m10_m11_n0_n10_n11; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); + } + + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + if(!block_2_ctile_map.ValidCTileIndex( + make_tuple(im0, in0), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_k0_m0_m1_k1, + make_multi_index(0, im0, 0, 0), + a_block_desc_k0_m0_m1_k1, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_k0_n0_n1_k1, + make_multi_index(0, in0, 0, 0), + b_block_desc_k0_n0_n1_k1, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + + // Initialize C + c_thread_buf.Clear(); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + + block_sync_lds(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..84e033e1e91bb9873ebed66906c6da5a287bf1bd --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp @@ -0,0 +1,608 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP +#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r2.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AKM0M1GridDesc a_k_m0_m1_grid_desc, + const BKN0N1GridDesc b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + cblockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDlops_km_kn_mn_v1r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K == b_k_n_grid_desc.GetLength(I0)) && + (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K) + { + const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc) + { + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor( + a_k_m_grid_desc, + make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc) + { + const auto K = b_k_n_grid_desc.GetLength(I0); + const auto N = b_k_n_grid_desc.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor( + b_k_n_grid_desc, + make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_m0_m10_m11_n0_n10_n11_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); + using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{})); + using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AKM0M1GridDesc& a_k_m0_m1_grid_desc, + const BKN0N1GridDesc& b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + + const auto K = a_k_m0_m1_grid_desc.GetLength(I0); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K_M0_M1, + ABlockTransferThreadClusterLengths_K_M0_M1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k_m0_m1_grid_desc), + decltype(a_k_m0_m1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_M1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_k_m0_m1_grid_desc, + make_multi_index(0, im0, 0), + a_k_m0_m1_block_desc, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K_N0_N1, + BBlockTransferThreadClusterLengths_K_N0_N1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k_n0_n1_grid_desc), + decltype(b_k_n0_n1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_N1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_k_n0_n1_grid_desc, + make_multi_index(0, in0, 0), + b_k_n0_n1_block_desc, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlockM1] is in LDS + // b_mtx[KPerBlocl, NPerBlockN1] is in LDS + // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + + constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + + ThreadwiseTensorSliceSet_v1{} + .Run(c_m10_m11_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{}; + constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k_m0_m1_global_move_slice_window_step_hack = + AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k_n0_n1_global_move_slice_window_step_hack = + BGridMoveSliceWindowStepHacks{}; + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_k_n0_n1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), + decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_buf, + CGridStepHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b1dfb0c73fc92a19662e01df31090008cdef87ff --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp @@ -0,0 +1,461 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_GRIDWISE_GEMM_V2_HPP +#define CK_GRIDWISE_GEMM_V2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "blockwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct GridwiseGemmDlops_km_kn_mn_v3 +{ + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto E = EPerBlock * 3 * 3; + + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); + + return a_block_space_size * sizeof(FloatAB); + } + + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + integral_constant, + integral_constant) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e_k_global_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); + + constexpr auto E = EPerBlock * 3 * 3; + + // const auto E = a_e_k_global_desc.GetLength(I0); + const auto K = a_e_k_global_desc.GetLength(I1); + + const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); + const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); + const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); + +// divide block work by [M, N] +#if 0 + const auto ho_block_work_num = Ho / Number{}; + const auto wo_block_work_num = Wo / Number{}; + const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num; + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num; + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#else + // Hack: this force result into SGPR + const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); + const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); + const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num); + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = + __builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num); + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#endif + + // lds max alignment + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + // c_thread_mtx definition: this is a mess + // TODO:: more elegent way of defining c_thread_mtx + constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + + auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const auto k_thread_id = c_thread_mtx_index.k; + const auto ho_thread_id = c_thread_mtx_index.h; + const auto wo_thread_id = c_thread_mtx_index.w; + + const index_t k_block_data_on_global = k_block_work_id * KPerBlock; + const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock; + const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock; + + const index_t ho_thread_data_on_global = + ho_block_data_on_global + ho_thread_id * HoPerThread; + const index_t wo_thread_data_on_global = + wo_block_data_on_global + wo_thread_id * WoPerThread; + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e_k_global_desc), + decltype(a_e_k_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1>, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_e_k_global_desc, + make_multi_index(0, k_block_data_on_global), + a_e_k_desc, + make_multi_index(0, 0)); + + constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + auto b_threadwise_transfer = + ThreadwiseTensorSliceTransfer_v2, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + 1, + true>( + b_e_n_ho_wo_global_desc, + make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); + + auto a_block_buf = make_dynamic_buffer( + p_shared_block, a_e_k_desc.GetElementSpaceSize()); + + // register allocation for output + StaticBuffer + c_thread_buf; + + // initialize output thread tensor + ThreadwiseTensorSliceSet_v1>{} + .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); + + constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{}; + constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{}; + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = + BGlobalMoveSliceWindowStepHacks{}; + + // double regsiter buffer for b + StaticBuffer + b_thread_even_buf, b_thread_odd_buf; + + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_even_buf, + b_e_n_ho_wo_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainKBlockLoop) + { + index_t e_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_odd_buf, + b_e_n_ho_wo_global_step_hacks); + + // LDS double buffer: GEMM on current data + // TODO: @Zhang Jing: blockwise gemm should be able to move slice window + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_even_buf, + b_e_n_ho_wo_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + e_block_data_begin += 2 * EPerBlock; + + } while(e_block_data_begin < E - 2 * EPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_odd_buf, + b_e_n_ho_wo_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + + // output: register to global memory + { + // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{}; + + const index_t k_thread_data_on_global = + k_block_data_on_global + k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>( + c_k_n_ho_wo_global_desc, + make_multi_index( + k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) + .Run(c_k_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + c_k_n_ho_wo_global_desc, + c_global_buf, + c_k_n_ho_wo_global_tensor_step_hacks); + } + } + + // pass tensor descriptor by reference + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + p_shared_block, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by their pointers + template + __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_e_k_global_desc = *p_a_e_k_global_desc; + const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc; + const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by void* + template + __device__ void Run(const void* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const void* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const void* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_e_k_global_desc = *reinterpret_cast(p_a_e_k_global_desc); + const auto b_e_n_ho_wo_global_desc = + *reinterpret_cast(p_b_e_n_ho_wo_global_desc); + const auto c_k_n_ho_wo_global_desc = + *reinterpret_cast(p_c_k_n_ho_wo_global_desc); + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ace844338414668030e5dd6aeeaf559e1a7c1060 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp @@ -0,0 +1,1597 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_GRIDWISE_GEMM_V3_HPP +#define CK_GRIDWISE_GEMM_V3_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" +#include "blockwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActiv(p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3_resize_add( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_d_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid, + p_b_grid, + p_bias_grid, + p_d_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v3_maxpool( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatC* __restrict__ p_bias_grid, + FloatC* __restrict__ p_c_grid, + FloatC* __restrict__ p_d_grid, + const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::ConvBiasActivMaxpool(p_a_grid, + p_b_grid, + p_bias_grid, + p_c_grid, + p_d_grid, + p_shared_block, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDlops_km_kn_mn_v3 +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto E1 = Number{}; + static constexpr auto E2 = Number{}; + static constexpr auto K2 = Number{}; + + static constexpr auto NPerBlock = I1; + + static constexpr FloatAcc alpha = 0.3; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = Number{}; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(I1, Number{}, Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align); + + return a_block_space_size * sizeof(FloatAB); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + + const auto K0 = K / KPerBlock; + const auto N0 = N / NPerBlock; + const auto H0 = Ho / HoPerBlock; + const auto W0 = Wo / WoPerBlock; + + const index_t grid_size = K0 * N0 * H0 * W0; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0) + { + const bool has_main_e0_block_loop = E0 > 1; + + return has_main_e0_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop() + { + const bool has_main_e1_block_loop = ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1; + + return has_main_e1_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop() + { + const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0; + + return has_double_tail_e1_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc) + { + const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0); + const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor( + a_e0_e1_k_e2_grid_desc, + make_tuple(make_pass_through_transform(E0), + make_pass_through_transform(E1), + make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return a_e0_e1_k0_k1_e2_grid_desc; + } + + __host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor( + const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc) + { + const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0); + // const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1); + const auto N = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2); + const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3); + const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4); + // const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5); + + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Ho / (H1 * H2); + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Wo / (W1 * W2); + + const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = + transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc, + make_tuple(make_pass_through_transform(E0), + make_pass_through_transform(E1), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2)), + make_pass_through_transform(E2)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3, 4, 5>{}, + Sequence<6, 7, 8>{}, + Sequence<9>{})); + + return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Ho / (H1 * H2); + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Wo / (W1 * W2); + + const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor( + c_k_n_ho_wo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) + { + const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); + const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); + const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); + const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto H2 = Number{}; + const auto H1 = Number{}; + const auto H0 = Number{}; + + const auto W2 = Number{}; + const auto W1 = Number{}; + const auto W0 = Number{}; +#else + const auto H2 = HoPerThread / 2; + const auto H1 = HoPerBlock / HoPerThread; + const auto H0 = Hx / (H1 * H2); + + const auto W2 = WoPerThread / 2; + const auto W1 = WoPerBlock / WoPerThread; + const auto W0 = Wx / (W1 * W2); +#endif + + const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( + d_k_n_hx_wx_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) + { + const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); + const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); + const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); + const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); + + const auto K1 = Number{}; + const auto K0 = K / K1; + + const auto H2 = Number{}; + const auto H1 = Number{}; + + const auto W2 = Number{}; + const auto W1 = Number{}; + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto H0 = Number{}; + const auto W0 = Number{}; +#else + const auto H0 = Hx / (H1 * H2); + const auto W0 = Wx / (W1 * W2); +#endif + + const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( + d_k_n_hx_wx_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_unmerge_transform(make_tuple(H0, H1, H2)), + make_unmerge_transform(make_tuple(W0, W1, W2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); + + return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) + { + const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); + const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); + const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); + const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); + +#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR + const auto K0 = Number{}; + const auto N0 = Number{}; + const auto H0 = Number{}; + const auto W0 = Number{}; +#else + const auto K0 = K / KPerBlock; + const auto N0 = N / NPerBlock; + const auto H0 = Ho / HoPerBlock; + const auto W0 = Wo / WoPerBlock; +#endif + + const auto cblockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + return cblockid_to_k_n_ho_wo_block_cluster_adaptor; + } + + // using AGridDesc_E0_E1_K0_K1_E2 = + // decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{})); + // using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = + // decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{})); + // using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = + // decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{})); + // using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = + // decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{})); + + using CBlockIdToBlockClusterAdaptor_K_N_H_W = + decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); + + template + __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor( + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) + { + const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0); + const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1); + + return make_naive_tensor_descriptor_packed(make_tuple(K0, K1)); + } + + __host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor() + { + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + return c_k1_n_h2_w2_thread_gemm_desc; + } + + // using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor()); + + __host__ __device__ static constexpr auto GetBlockWiseGemm() + { + constexpr auto max_lds_align = Number{}; + + constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, Number{}), max_lds_align); + + constexpr auto b_e1_n_h_w_e2_block_gemm_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + + return blockwise_gemm; + } + + __device__ static constexpr auto GetCThreadIndex() + { + auto blockwise_gemm = GetBlockWiseGemm(); + auto c_thread_mtx_index = + blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); + + return c_thread_mtx_index; + }; + + __device__ static constexpr auto GetCBlockIndex( + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor) + { + const auto c_k_n_h_w_block_cluster_idx = + cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + return c_k_n_h_w_block_cluster_idx; + } + + template + __device__ static void BiasOp(BiasGlobalBuff& bias_global_buf, + CThreadBuff& c_thread_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc, + const CThreadDesc_K1_N_H2_W2&) + + { + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + + const auto k_thread_id = c_thread_idx[I0]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + constexpr auto bias_k0_k1_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + StaticBuffer + bias_thread_buf; + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + auto bias_threadwise_transfer = + ThreadwiseTensorSliceTransfer_v2{}>, + Sequence<0, 1>, + 1, + CThreadTransferDstScalarPerVector, + false, + true>( + bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global)); + + constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple( + make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{})); + + bias_threadwise_transfer.Run(bias_k0_k1_grid_desc, + bias_global_buf, + bias_k0_k1_thread_desc, + make_tuple(I0, I0), + bias_thread_buf, + bias_k0_k1_global_tensor_step_hacks); + + static_for<0, KPerThread, 1>{}([&](auto ki) { + static_for<0, HoPerThread, 1>{}([&](auto hi) { + static_for<0, WoPerThread, 1>{}([&](auto wi) { + constexpr index_t c_offset = + c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi)); + c_thread_buf(Number{}) = + c_thread_buf[Number{}] + bias_thread_buf[ki]; + }); + }); + }); + } + + template + __device__ static void Activation(CThreadBuff& c_thread_buf, + const CThreadDesc_K1_N_H2_W2&, + integral_constant) + { + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { + if constexpr(activ_type_ == 1) + { + c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i]; + } + else if constexpr(activ_type_ == 2) + { + FloatAcc x = 1.0 + exp(-c_thread_buf[i]); + + asm volatile("\n \ + v_rcp_f32 %0, %1 \n" + : "=v"(x) + : "0"(x)); + + c_thread_buf(i) = x; + } + }); + } + + template + __device__ static void + WriteOut(const CThreadBuff& c_thread_buf, + CGlobalBuff& c_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) + { + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + // hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global + // tensor + constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{}; + + constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc), + decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + c_global_buf, + c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks); + } + + template + __device__ static void + MaxPool(const CThreadBuff& c_thread_buf, + DGlobalBuff& d_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CThreadDesc_K1_N_H2_W2&, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) + { + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0, ""); + + constexpr auto HoPerThread_2 = HoPerThread / 2; + constexpr auto WoPerThread_2 = WoPerThread / 2; + + constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + StaticBuffer + d_thread_buf; + + static_for<0, KPerThread, 1>{}([&](auto ki) { + static_for<0, HoPerThread_2, 1>{}([&](auto hi) { + static_for<0, WoPerThread_2, 1>{}([&](auto wi) { + constexpr index_t d_offset = + d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset( + make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi)); + + constexpr index_t c_offset_0 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2, wi * 2)); + constexpr index_t c_offset_1 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2, wi * 2 + 1)); + constexpr index_t c_offset_2 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2 + 1, wi * 2)); + constexpr index_t c_offset_3 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( + make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1)); + + d_thread_buf(Number{}) = c_thread_buf[Number{}]; + d_thread_buf(Number{}) = + fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); + d_thread_buf(Number{}) = + fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); + d_thread_buf(Number{}) = + fmax(c_thread_buf[Number{}], d_thread_buf(Number{})); + }); + }); + }); + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, + true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + d_thread_buf, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + d_global_buf, + d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); + } + + template + __device__ static void + ResizeAdd(const CThreadBuff& c_thread_buf, + DGlobalBuff& d_global_buf, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const CThreadDesc_K1_N_H2_W2&, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) + { + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + const auto k_thread_id = c_thread_idx[I0]; + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + constexpr auto HoPerThreadx2 = HoPerThread * 2; + constexpr auto WoPerThreadx2 = WoPerThread * 2; + + constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{})); + + StaticBuffer + d_thread_buf; + + static_for<0, KPerThread, 1>{}([&](auto k_i) { + static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) { + static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) { + d_thread_buf(Number{}) = + c_thread_buf[Number{}]; + }); + }); + }); + + // hack to control index calculation when iterating over d_k_n_ho_wo_global tensor + constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; + + const index_t k_thread_data_on_global = k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), + decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + InMemoryDataOperationEnum::Add, + 1, + true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + make_multi_index(k_block_work_id, + k_thread_data_on_global, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0)) + .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + d_thread_buf, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + d_global_buf, + d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); + } + + template + __device__ static void + GemmOp(const AGlobalBuff& a_global_buf, + const BGlobalBuff& b_global_buf, + CThreadBuff& c_thread_buf, + FloatAB* __restrict__ p_shared_block, + const CBlockIndex& c_block_idx, + const CThreadIndex& c_thread_idx, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CThreadDesc_K1_N_H2_W2&, + integral_constant) + { + constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop(); + constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop(); + + // const auto c_k_n_h_w_block_cluster_idx = + // GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + // cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( + // make_multi_index(get_block_1d_id())); + + const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); + const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); + const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); + const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); + + constexpr auto max_lds_align = Number{}; + + constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, Number{}), max_lds_align); + + constexpr auto b_e1_n_h_w_e2_block_gemm_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + // blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); + + const auto ho_thread_id = c_thread_idx[I2]; + const auto wo_thread_id = c_thread_idx[I3]; + + constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, I1, Number{}, Number{}), + max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e0_e1_k0_k1_e2_grid_desc), + decltype(a_e0_e1_k0_k1_e2_block_copy_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorDim, + 4, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_E2, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + false>(a_e0_e1_k0_k1_e2_grid_desc, + make_multi_index(0, 0, k_block_work_id, 0, 0), + a_e0_e1_k0_k1_e2_block_copy_desc, + make_multi_index(0, 0, 0, 0, 0)); + + constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0); + + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + I1, + I1, + I1, + Number{}, + I1, + I1, + Number{}, + Number{})); + + auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2< + FloatAB, + FloatAB, + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc), + decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc), + Sequence, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + make_multi_index(0, + 0, + n_block_work_id, + ho_block_work_id, + ho_thread_id, + 0, + wo_block_work_id, + wo_thread_id, + 0, + 0)); + + auto a_block_buf = make_dynamic_buffer( + p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize()); + + //// register allocation for output + // StaticBuffer + // c_thread_buf; + + // initialize output thread tensor + ThreadwiseTensorSliceSet_v1>{} + .Run(c_k1_n_h2_w2_thread_gemm_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto b_thread_slice_copy_step = + make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{}; + constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{}; + + // double regsiter buffer for b + StaticBuffer + b_thread_even_buf, b_thread_odd_buf; + + if constexpr(HasMainE0BlockLoop) + { + const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0); + + index_t e0_block_data_begin = 0; + + do + { + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead( + a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainE1BlockLoop) + { + index_t e1_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + e1_block_data_begin += 2 * E1PerBlock; + + } while(e1_block_data_begin < E1 - 2 * E1PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + + a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc, + a_block_slice_copy_step, + AGlobalMoveSliceWindowStepHacks{}); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + e0_block_data_begin += 1; + + } while(e0_block_data_begin < E0); + } + else + { + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead( + a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainE1BlockLoop) + { + index_t e1_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow( + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_even_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + e1_block_data_begin += 2 * E1PerBlock; + + } while(e1_block_data_begin < E1 - 2 * E1PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_thread_slice_copy_step, + BGlobalMoveSliceWindowStepHacks{}); + + b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + b_global_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_odd_buf, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + } + } + + template + __device__ static void + Conv(const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant) + { + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + } + + template + __device__ static void ConvBiasActiv( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + } + + template + __device__ static void ConvBiasActivMaxpool( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_c_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Output + WriteOut(c_thread_buf, + c_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + // MaxPool + MaxPool(c_thread_buf, + d_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k1_n_h2_w2_thread_gemm_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); + } + + template + __device__ static void ConvBiasActivResizeAdd( + const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + const FloatC* __restrict__ p_bias_global, + FloatC* __restrict__ p_d_global, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, + const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, + const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, + const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, + integral_constant, + integral_constant) + { + static constexpr auto activ_type = integral_constant{}; + + const auto bias_k0_k1_grid_desc = + MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); + auto d_global_buf = make_dynamic_buffer( + p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); + auto bias_global_buf = make_dynamic_buffer( + p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); + + constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); + + // register allocation for output + StaticBuffer + c_thread_buf; + + const auto c_k_n_h_w_block_cluster_idx = + GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); + + const auto c_thread_mtx_index = GetCThreadIndex(); + + // GemmOp + GemmOp(a_global_buf, + b_global_buf, + c_thread_buf, + p_shared_block, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + a_e0_e1_k0_k1_e2_grid_desc, + b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc, + integral_constant{}); + + // Bias + BiasOp(bias_global_buf, + c_thread_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + bias_k0_k1_grid_desc, + c_k1_n_h2_w2_thread_gemm_desc); + + // Activ + Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); + + // Resize_Add + ResizeAdd(c_thread_buf, + d_global_buf, + c_k_n_h_w_block_cluster_idx, + c_thread_mtx_index, + c_k1_n_h2_w2_thread_gemm_desc, + d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); + } +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..578665ea85fb723565e597bbec26457ffb449606 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -0,0 +1,944 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +template +struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + template + static constexpr auto MakeTsGridPointer() + { + return generate_tuple( + [&](auto i) { + using T = remove_cvref_t>; + if constexpr(isConst) + return static_cast(nullptr); + else + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // A desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const EGridDesc_M_N& e_grid_desc_m_n, + const RGridDesc_M& r_grid_desc_m, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(AGridDesc_M_K::GetNumOfDimension() == 2); + static_assert(BGridDesc_N_K::GetNumOfDimension() == 2); + static_assert(EGridDesc_M_N::GetNumOfDimension() == 2); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + if(M != r_grid_desc_m.GetLength(I0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeRGridDescriptor_MBlock_MPerBlock(const RGridDesc_M& r_grid_desc_m) + { + const auto M = r_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto r_grid_desc_mblock_mperblock = transform_tensor_descriptor( + r_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return r_grid_desc_mblock_mperblock; + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using DefaultAGridDesc_AK0_M_AK1 = + remove_cvref_t; + using DefaultBGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + // Support 2 dimension in the future. Not only M + using RGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeTsGridPointer()); + using RsGridPointer = decltype(MakeTsGridPointer()); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + RsGridPointer p_rs_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const StaticallyIndexedArray& + ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const StaticallyIndexedArray& + rs_grid_desc_mblock_mperblock, // FIXME: Rs desc may be of different + const Block2ETileMap& block_2_etile_map) + { + // FIXME - Share code with other gemm kernel + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto rs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_rs_grid(i), rs_grid_desc_mblock_mperblock[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + Ds + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_der_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) * + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR cde_reduce_thread_desc_mperblock_nperblock + constexpr auto cde_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto r_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + constexpr auto r_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto e_thread_buf = make_static_buffer( + cde_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + // To apply D0, D1, ... and reduction. + // Copy c shuffle from LDS back to VGPR + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(cde_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CDEReduceThreadTransferScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + // Copy result of reduction back from VGPR to global + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_r_grid = p_rs_grid[I]; + auto r_element_op = rs_element_op[I]; + auto r_grid_desc_mblock_mperblock = rs_grid_desc_mblock_mperblock[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(r_thread_desc_mblock_mperblock), + decltype(r_grid_desc_mblock_mperblock), + decltype(r_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + RThreadTransferDstScalarPerVector_MPerBlock, + RsGlobalMemoryDataOperation::At(I), + 1, + false>{r_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + r_element_op}; + }, + Number{}); + + // D0, D1, ..., Dn + constexpr auto cde_reduce_thread_desc_I1_mperblock_I1_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // FIXME: Decrease usage of VGPR + // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR + auto ds_thread_buf = generate_tuple( + [&](auto) { + return make_static_buffer( + cde_reduce_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize()); + }, + Number{}); + + // Copy D0, D1, ..., Dn from global to VGPR + auto ds_thread_copy_global_to_vgpr = generate_tuple( + [&](auto I) { + using DDataType = remove_cvref_t>; + return ThreadwiseTensorSliceTransfer_v2< + DDataType, + FloatReduceAcc, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]), + decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CDEReduceThreadTransferScalarPerVector_NPerBlock, + 1, + true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I], + make_multi_index( + I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + }, + Number{}); + + auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatE, + decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + CDEReduceThreadTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{ + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to read from LDS + block_sync_lds(); + + // each thread shuffle data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + block_sync_lds(); + + // Get shuffle data from LDS to VGPR + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + cde_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + e_thread_buf); + + // Global read D0, D1, ... + static_for<0, NumDTensor, 1>{}([&](auto Id) { + auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id); + d_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], + ds_grid_buf[Id], + cde_reduce_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + ds_thread_buf(Id)); + + if constexpr(access_id < num_access - 1) + { + // move on D0, D1, ... + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + d_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step); + } + }); + + // cde_element_op(e, c, d0, d1, ...); + static_for<0, cde_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + const auto c_ds_src_data_refs = concat_tuple_of_reference( + tie(e_thread_buf[i]), + generate_tie( + [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, + Number{})); + auto e_dst_data_refs = tie(e_thread_buf(i)); + unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs); + }); + + // Global write E + e_thread_copy_vgpr_to_global.Run(cde_reduce_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + if constexpr(access_id < num_access - 1) + { + // move on E + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + e_thread_copy_vgpr_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step); + } + + // reduction + static_for<0, NumRTensor, 1>{}([&](auto Ir) { + auto r_thread_buf = make_static_buffer( + r_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(Ir); + + using ThreadReduceOperation = + remove_cvref_t; + + using ThreadwiseReduce = + ThreadwiseReduction; + + // threadwise reduction + const auto reduce_identityVal = + ThreadReduceOperation::template GetIdentityValue(); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { r_thread_buf(I) = reduce_identityVal; }); + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset)); + }); + }); + ThreadwiseReduce::Reduce(e_thread_buf, r_thread_buf); + + // gridwise reduction + reduce_thread_copy_vgpr_to_global.Run(r_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + r_thread_buf, + rs_grid_desc_mblock_mperblock[Ir], + rs_grid_buf(Ir)); + + if constexpr(access_id < num_access - 1) + { + // move on R0, R1, ... + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + rs_grid_desc_mblock_mperblock[Ir], + make_tuple(de_global_step[I0], de_global_step[I1])); + } + }); + }); // copy c, d, e + reduction + + } // shuffle C + Ds + reduction + write out + } // Run +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..da0b0cea241d7b62dd8e138aa5dd7cb0ccdaa0cd --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,753 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ABDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp new file mode 100644 index 0000000000000000000000000000000000000000..98331d854490abe4e7f478d817104218cb49c28f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" + +namespace ck { + +enum struct PipelineVersion +{ + v1, + v2, +}; + +template +constexpr auto GridwiseGemmPipeline_Selector() +{ + if constexpr(PipelineVer == PipelineVersion::v1) + { + if constexpr(LoopSched == LoopScheduler::Default) + { + return GridwiseGemmPipeline_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return GridwiseGemmPipelineInterwave_v1{}; + } + } + else if constexpr(PipelineVer == PipelineVersion::v2) + { + return GridwiseGemmPipeline_v2{}; + } + else + { + std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9097552c2baf236b68370d6d98a771a2eec4e05 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -0,0 +1,369 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" + +namespace ck { + +template +struct GridwiseGemmPipeline_v1; + +// 1-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<1> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// 2-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + // Write num_loop - 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +template +struct GridwiseGemmPipelineInterwave_v1; + +template <> +struct GridwiseGemmPipelineInterwave_v1<1> +{ + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // block_sync_lds(); // moved into blockwise_gemm + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// Note: 2 stage prefetch not optimized for inter-wave loop scheduler +template <> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> +{ +}; + +// TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality +template +constexpr auto GridwiseGemmPipeline_v1_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return GridwiseGemmPipeline_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return GridwiseGemmPipelineInterwave_v1{}; + } +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3281b910d3d8782fc18f91ee77ce42aa44fc5ba9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +struct GridwiseGemmPipeline_v2 +{ + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to 1 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write 0 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global Read 1 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds(); + + // GEMM i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // move to i + 2 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write i + 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global read i + 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write i + 1 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global read i + 2 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + ++i; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // LDS write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + block_sync_lds(); + + // GEMM num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2fe55068449377a82b0c4bd12c65f05ea4e96199 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,879 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_reduces_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + ReduceGlobalMemoryDataOperation::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // TODO - extract following into reduction_blockwise + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = + ReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + + // Reduction + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aa89bff9ee21873fdc57c839eb800be455005a9a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmSplitKMultipleD_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, src of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, src of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ABDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, + const int split_k) + { + const auto MRaw = a_grid_desc_m_k.GetLength(I0); + const auto KRaw = a_grid_desc_m_k.GetLength(I1); + + const index_t AK0 = + (math::integer_divide_ceil(KRaw, KPerBlock * split_k) * KPerBlock) / AK1; + const index_t K = split_k * AK0 * AK1; + const auto KPad = K - KRaw; + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(split_k, AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + + // B desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, + const int split_k) + { + const auto NRaw = b_grid_desc_n_k.GetLength(I0); + const auto KRaw = b_grid_desc_n_k.GetLength(I1); + + const index_t BK0 = + (math::integer_divide_ceil(KRaw, KPerBlock * split_k) * KPerBlock) / BK1; + const index_t K = split_k * BK0 * BK1; + const auto KPad = K - KRaw; + + const auto b_grid_desc_n_kpad = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_n_kpad, + make_tuple(make_unmerge_transform(make_tuple(split_k, BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const EGridDescriptor_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDescriptor_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, const int split_k) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + e_grid_desc_m_n, 8, split_k); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_akb_ak0_m_ak1.GetLength(I2); + const auto N = b_grid_desc_bkb_bk0_n_bk1.GetLength(I2); + const auto K = + a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3); + + if(K != b_grid_desc_bkb_bk0_n_bk1.GetLength(I1) * b_grid_desc_bkb_bk0_n_bk1.GetLength(I3)) + { + return false; + } + if(a_grid_desc_akb_ak0_m_ak1.GetLength(I0) != b_grid_desc_bkb_bk0_n_bk1.GetLength(I0)) + { + return false; + } + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DefaultAGridDesc_AK0_M_AK1 = + remove_cvref_t; + using DefaultBGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(block_work_idx[Number<0>{}] == 0) + { + Run0(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run1(p_a_grid, + p_b_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } + template + __device__ static void Run0(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t k_batch_id = block_work_idx[I0]; + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_akb_ak0_m_ak1 = + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bkb_bk0_n_bk1 = + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_akb_ak0_m_ak1), + decltype(a_block_desc_akb_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_akb_ak0_m_ak1, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_akb_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bkb_bk0_n_bk1), + decltype(b_block_desc_bkb_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bkb_bk0_n_bk1, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bkb_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_akb_ak0_m_ak1, + a_block_desc_akb_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bkb_bk0_n_bk1, + b_block_desc_bkb_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + { + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t, + uniform_sequence_gen_t< + NumDTensor, + false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + } + + template + __device__ static void Run1(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize()); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t k_batch_id = block_work_idx[I0]; + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_akb_ak0_m_ak1 = + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bkb_bk0_n_bk1 = + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_akb_ak0_m_ak1), + decltype(a_block_desc_akb_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_akb_ak0_m_ak1, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_akb_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bkb_bk0_n_bk1), + decltype(b_block_desc_bkb_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bkb_bk0_n_bk1, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bkb_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_akb_ak0_m_ak1, + a_block_desc_akb_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bkb_bk0_n_bk1, + b_block_desc_bkb_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + { + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation, + EGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + EDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ecc528a7ed79a337efd02cc5efe9dd11039b3b29 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,653 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..94e181cd45470c08f4f7b35b84ff991a56babb82 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -0,0 +1,1068 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_layernorm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, // MxN + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + // TODO ANT: separate into MMA + Epilogue + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_bias_grid, + p_c0_add_grid, + p_c0_gamma_grid, + p_c0_beta_grid, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + block_2_ctile_map); + + // TODO ANT: Run layernorm epilogue here +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_bias_grid; + ignore = p_c0_add_grid; + ignore = p_c0_gamma_grid; + ignore = p_c0_beta_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +template +struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + // LDS allocation for reduction workspace + constexpr index_t c_lds_workspace_size = BlockSize; + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size_aligned * sizeof(FloatCShuffle) + + c_lds_workspace_size * sizeof(FloatReduceAcc)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // in order to reduce N dim without elaborate sync across CUs in single kernel, one + // workgroup must span the entire N extent + if(math::integer_divide_ceil(N, NPerBlock) > 1) + { + return false; + } + + // static check: all waves in the workgroups combined must cover whole N extent in order + // to have efficient N-dim reduction + static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave, + "condition not met for efficient layernorm"); + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // for bias, beta, gamma + __host__ __device__ static constexpr auto + MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n) + { + const auto N = c0_grid_desc_n.GetLength(I0); + const auto NBlock = N / NPerBlock; + + const auto c0_grid_desc_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_n, + make_tuple(make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return c0_grid_desc_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C0GridDescriptor_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_bias_grid_buf = make_dynamic_buffer( + p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_grid_buf = make_dynamic_buffer( + p_c0_add_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_gamma_grid_buf = make_dynamic_buffer( + p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + auto c0_beta_grid_buf = make_dynamic_buffer( + p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0); + + // for broadcasting bias, beta, gamma + const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_nblock_nperblock, + make_tuple(make_insert_transform(I1), + make_insert_transform(I1), + make_pass_through_transform(NBlock), + make_pass_through_transform(NPerBlock)), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // pytorch default + // https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html + static constexpr FloatReduceAcc epsilon = 1e-5; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // VGPR d_reduce_thread_desc_mperblock + constexpr auto d_reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // TODO: this should be implemented as a blockwise reduction + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + auto d_reduce_work_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + c_block_size_aligned), + BlockSize); + + // Sum thread workspace + auto d0_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // Squared sum thread workspace + auto d1_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto c_reduce_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatCShuffle, + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_block_desc_mperblock_nperblock), + tensor_operation::element_wise::PassThrough, + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, + c_reduce_thread_data_idx_begin, + tensor_operation::element_wise::PassThrough{}}; + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + block_sync_lds(); + + // load from LDS and global, add bias + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_bias_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + acc_element_op(out, + c_reduce_thread_buf(i) + + static_cast(c0_thread_buf(i))); + c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias) + }); + + c0_add_thread_copy_global_to_vgpr.Run( + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_add_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // add + }); + + // layernorm + { + using ThreadwiseReduceD0 = + ThreadwiseReduction; + using ThreadwiseReduceD1 = + ThreadwiseReduction; + + const auto d0_zeroVal = + ThreadwiseReduceD0::Op::template GetIdentityValue(); + const auto d1_zeroVal = + ThreadwiseReduceD1::Op::template GetIdentityValue(); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d1_thread_buf(i) = d1_zeroVal; }); + + // reduce sum in VGPR + ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf); + + // reduce squared sum in VGPR + ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf); + + // reduce within workgroup + using BlockwiseReduce = PartitionedBlockwiseReduction< + FloatReduceAcc, + BlockSize, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K + Sequence<1, 0>, // ThreadClusterArrangeOrder + reduce::Add, + false>; + + static_for<0, mreduce_per_thread, 1>{}([&](auto i) { + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d0_thread_buf(i)); // blockwise reduced sum + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d1_thread_buf(i)); // blockwise reduced squared sum + }); + + // normalize + const index_t NRaw = + c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0] + .GetUpperLengths()[I1]; // TODO: proper handle + + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto dst_offset = + Number{}; + + constexpr auto src_offset = + Number{}; + + FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; + + FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; + FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; + FloatReduceAcc divisor_sqrt; + tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); + + c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; + }); + }); + + // scaling + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_gamma_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) *= + static_cast(c0_thread_buf(i)); // * gamma + }); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_beta_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // + beta + }); + + block_sync_lds(); + + c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf, + c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf); + + } // end layernorm + + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0_add + c0_add_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp new file mode 100644 index 0000000000000000000000000000000000000000..126887cbacd727a7c5a3d2026599e0a9c4d9e2e3 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -0,0 +1,983 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v4_no_carry +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v4_no_carry() = default; + + __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_up_diff, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_up_diff[I0]; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod_wrw, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths) +{ + return Merge_v4_no_carry{low_lengths}; +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + // M0/M1/M1Padding + static constexpr auto M1PerBlock = Number{}; + static constexpr auto M0PerBlock = Number{}; + static constexpr auto M1Padding = Number{}; + + // N0/N1/N1Padding + static constexpr auto N1PerBlock = Number{}; + static constexpr auto N0PerBlock = Number{}; + static constexpr auto N1Padding = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_block_desc_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_b_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + M1Padding), + Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_b_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return a_block_desc_b_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return a_block_desc_b_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_block_desc_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_b_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + N1Padding), + Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_b_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return b_block_desc_b_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return b_block_desc_b_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = math::integer_least_multiple( + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + // const bool has_main_k0_block_loop = K0 > K0PerBlock; + const index_t num_loop = K0 / K0PerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + + // return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N00_M01_N01( + c_m_n_grid_desc, M01, N01, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + constexpr index_t KPack = + math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + void* p_shared = static_cast(p_shared_block); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXDL, ""); + static_assert(N2 == NPerXDL, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; // namespace ck + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2aad7128f0681bcd903c2c11249380e3883c3e53 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -0,0 +1,678 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_skip_b_lds_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + static constexpr index_t WaveSize = 64; + static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K0 / K0PerBlock) % BBlockBufferSize == 0)) + { + return false; + } + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + // TODO move this function into GEMM-pipeline class + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / (BBlockBufferSize * K0PerBlock)) > 1; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) + { + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + + const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor( + b_grid_desc_k0_n_k1, + make_tuple(make_unmerge_transform( + make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)), + make_unmerge_transform(make_tuple( + N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{})); + return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1; + } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetWaveKNIdx(const index_t thread_id) + { + constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return wave_threadid_to_nk_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix threadwise copy + constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + I1, + Number{}, // K0PerThread + I1, // NBlockId + Number{}, // repeat + I1, // waves + I1, // NPerXdlops + Number{})); + + using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_block_desc_k0_m_k1), + decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3), + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + K1>; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 = + decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 1>(a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + ignore = b_element_op; + // B matrix threadwise copy + constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + I1, + Number{}, // K0PerThread + I1, // NBlockId + Number{}, // repeat + I1, // waves + I1, // NPerXdlops + Number{})); + + auto b_thread_buf = generate_tuple( + [&](auto i) { + ignore = i; + return StaticBuffer{}; + }, + Number{}); + + const auto wave_id = GetWaveIdx(); + const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]); + +#if 0 + const index_t block_id = get_block_1d_id(); + const index_t thread_id = get_thread_local_1d_id(); + printf("block id: %d m blockid: %d n block id: %d ,thread id: %d, wave id :{%d %d %d} " + "kn id: {%d %d}\n", + block_id, + block_work_idx[I0], + block_work_idx[I1], + thread_id, + wave_id[I0], + wave_id[I1], + wave_id[I2], + wave_k_n_id[I0], + wave_k_n_id[I1]); + printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t", + xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0)); +#endif + + auto b_threadwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + make_multi_index( + 0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_block_desc_k0_m_k1), + decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3), + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + K1>{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + // gridwise GEMM pipeline + constexpr auto a_block_slice_copy_step = + make_multi_index(K0PerBlock * BBlockBufferSize, 0, 0); + constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); + // preload data to regiester and LDS + { + // Read + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + + static_for<0, BBlockBufferSize, 1>{}([&](auto ii) { + b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_grid_buf, + b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_buf(Number{})); + b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_thread_slice_copy_step); + }); + + // Initialize C + c_thread_buf.Clear(); + // a data write to lds + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + // main body + if constexpr(HasMainK0BlockLoop) + { + index_t K0BlockMainLoop = + __builtin_amdgcn_readfirstlane(K0 / (BBlockBufferSize * K0PerBlock)); + index_t i = 0; + do + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + blockwise_gemm.ResetABlockStartWindow(); + block_sync_lds(); + + static_for<0, BBlockBufferSize, 1>{}([&](auto ii) { + blockwise_gemm.Run(a_block_buf, b_thread_buf(Number{}), c_thread_buf); + blockwise_gemm.MoveABlockSliceWindow(); + s_nop(); + + b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_grid_buf, + b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_buf(Number{})); + b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_thread_slice_copy_step); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + // move a and b window + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + + i += 1; + } while(i < (K0BlockMainLoop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.ResetABlockStartWindow(); + + static_for<0, BBlockBufferSize, 1>{}([&](auto ii) { + blockwise_gemm.Run(a_block_buf, b_thread_buf(Number{}), c_thread_buf); + blockwise_gemm.MoveABlockSliceWindow(); + }); + } + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d1149c0c2e3c80305e6cb415568cb1fa1fbb7a39 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -0,0 +1,557 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_block_desc_k0_m_k1), + decltype(b_block_desc_k0_n_k1), + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + K1, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp new file mode 100644 index 0000000000000000000000000000000000000000..949d56483669946f12f40adca568d4e2a6b0bb5b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -0,0 +1,616 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0), + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1)))) + { + return; + } + + const index_t k_batch_id = block_work_idx[I0]; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..190194f1eb18945eb0e6dc3334e9e6a5a5248267 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -0,0 +1,721 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + static_cast(p_shared_block), + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared_block, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t k_batch_id = block_work_idx[I0]; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = static_cast(p_shared_block); + FloatAB* p_b_block = static_cast(p_shared_block) + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ffb2926c877e1569847da43fe08c05e30073dc84 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -0,0 +1,723 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatCShuffle, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_AK0_M_AK1, + typename BGridDesc_BK0_N_BK1, + typename CGridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1Value, + index_t BK1Value, + index_t MPerXdl, + index_t NPerXdl, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1, + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + constexpr auto max_lds_align = AK1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(AK0, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + constexpr auto max_lds_align = BK1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(BK0, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t k_pack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_shuffle_block_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7e6dbb3b2ed85f55ec062a7b60cedc7b40337d54 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -0,0 +1,762 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1, + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fb1e34b98543665717624af2201d9cbfe4596cda --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -0,0 +1,801 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = p_c1_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename C1GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1, + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_c1_grid, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename Src2Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89efea4d6c36a53fdb98895838848d271857debd --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Y = Normalization(X, Beta, Gamma) +template +struct GridwiseNormalizationNaiveVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + using ThreadwiseSumReduce = ThreadwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& beta_grid_desc_m_k, + const GridDesc_M_K& y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const AccElementwiseOperation acc_elementwise_op) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + x_thread_buf; + + StaticBuffer + gamma_thread_buf; + + StaticBuffer& beta_thread_buf = gamma_thread_buf; + + StaticBuffer + y_thread_buf; + + StaticBuffer& x_square_thread_buf = y_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer + mean_square_thread_buf; + StaticBuffer& var_thread_buf = + mean_square_thread_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); + mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue(); + }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + acc_elementwise_op); + + // Copy x from Cache + // one pass: fwd, second pass: bwd + constexpr auto thread_copy_fwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + // E(x), E[x^2], var(x) + // FIXME: Should not hack the transform from deviceOP + int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + + index_t reducedTiles = 0; + do + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(Number{}) = + x_thread_buf(Number{}) * x_thread_buf(Number{}); + }); + }); + + ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf); + ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + + ++reducedTiles; + } while(reducedTiles < num_k_block_tile_iteration); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); + mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; + + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); + mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; + + // var(x) = E[x^2] - E[x]^2 + var_thread_buf(I) = + mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + }); + + // y = (x - E[x]) / sqrt(var[x] + epsilon) + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + reducedTiles = 0; + do + { + if constexpr(!SweepOnce) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + } + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // normalize + y_thread_buf(Number{}) = + (x_thread_buf(Number{}) - mean_thread_buf(iM)) / + sqrt(var_thread_buf(iM) + epsilon); + + // gamma + y_thread_buf(Number{}) = + y_thread_buf(Number{}) * gamma_thread_buf(Number{}); + }); + }); + + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // beta + y_thread_buf(Number{}) = + y_thread_buf(Number{}) + beta_thread_buf(Number{}); + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); + + ++reducedTiles; + } while(reducedTiles < num_k_block_tile_iteration); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7aefd3c066b3b29c2af12bfce8cf98b5b900b755 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Y = Normalization(X, Beta, Gamma) +template +struct GridwiseNormalizationWelfordVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto XThreadBufferNumber = Number{}; + static constexpr auto GammaThreadBufferNumber = Number{}; + static constexpr auto BetaThreadBufferNumber = Number{}; + static constexpr auto YThreadBufferNumber = Number{}; + + __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, + int thread_k_cluster_id) + { + // FIXME: Should not hack the transform from deviceOP + int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + int kPerThread = + kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); + int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + + __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& beta_grid_desc_m_k, + const GridDesc_M_K& y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const AccElementwiseOperation acc_elementwise_op) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto beta_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto y_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * BetaSrcVectorSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * YDstVectorSize), + acc_elementwise_op); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + if constexpr(!SweepOnce) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + } + + static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon); + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..de1ae915920503a086d2cf939ae6e41d27a28838 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_nd_permute(const InGridDesc in_grid_desc, + const OutGridDesc out_grid_desc, + const InDataType* p_in_global, + OutDataType* p_out_global, + const ElementwiseOperation elementwise_op, + const Block2TileMap block_2_tile_map) +{ + __shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()]; + + GridwisePermute::Run(in_grid_desc, + out_grid_desc, + p_in_global, + p_out_global, + p_shared, + elementwise_op, + block_2_tile_map); +} + +template +struct GridwisePermute +{ + static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension()); + static_assert(3 <= InGridDesc::GetNumOfDimension()); + static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim && + SrcVectorDim < InGridDesc::GetNumOfDimension()); + static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim && + DstVectorDim < OutGridDesc::GetNumOfDimension()); + static_assert(SrcVectorDim != DstVectorDim); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ThisThreadBlock = ThisThreadBlock; + + struct Block2TileMap + { + static constexpr index_t NumDim = InGridDesc::GetNumOfDimension(); + static_assert(3 <= NumDim); + + static constexpr auto I0 = Number<0>{}; + + Block2TileMap() = delete; + Block2TileMap(const Block2TileMap&) = default; + Block2TileMap(Block2TileMap&&) = delete; + + ~Block2TileMap() = default; + + Block2TileMap& operator=(const Block2TileMap&) = delete; + Block2TileMap& operator=(Block2TileMap&&) = delete; + + explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {} + + __host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const + { + const auto N0 = + math::integer_divide_ceil(desc.GetLength(Number{}), NPerBlock); + const auto H0 = + math::integer_divide_ceil(desc.GetLength(Number{}), HPerBlock); + const auto W0 = + math::integer_divide_ceil(desc.GetLength(Number{}), WPerBlock); + + const index_t grid_size = N0 * H0 * W0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == 1); + + auto block_1d_id = idx_top[I0]; + + const auto N0 = + math::integer_divide_ceil(desc_.GetLength(Number{}), NPerBlock); + const auto H0 = + math::integer_divide_ceil(desc_.GetLength(Number{}), HPerBlock); + const auto W0 = + math::integer_divide_ceil(desc_.GetLength(Number{}), WPerBlock); + + block_1d_id = block_1d_id % (N0 * H0 * W0); + + index_t idx_N0 = block_1d_id / (H0 * W0); + index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0; + index_t idx_W0 = block_1d_id % W0; + + return make_tuple(idx_N0, idx_H0, idx_W0); + } + + private: + const InGridDesc desc_; + }; + + using DefaultBlock2TileMap = Block2TileMap; + + // use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay + __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock() + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + I1)); + } + + // for N-dimension descriptor, reserve its last 2 dimensions, then merge its leading dimensions + // into single one. finally, form a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] -> + // [(d(0) x d(1) x ...), d(N - 2), d(N - 1)] + template + __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc) + { + constexpr index_t NumDim = GridDesc::GetNumOfDimension(); + static_assert(3 <= NumDim); + + const auto merged_desc = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(generate_tuple( + [&](auto I) { return desc.GetLength(I); }, Number{})), + make_pass_through_transform(desc.GetLength(Number{})), + make_pass_through_transform(desc.GetLength(Number{}))), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{}), + Sequence{}, + Sequence{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return merged_desc; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto in_block_desc_nperblock_hperblock_wperblock = + GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock(); + + return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() * + sizeof(InDataType); + } + + __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc) + { + return DefaultBlock2TileMap{desc}; + } + + __host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc, + const OutGridDesc& out_grid_desc) + { + constexpr index_t NumDim = InGridDesc::GetNumOfDimension(); + + // check if we only swap last 2 dimensions + bool valid = true; + static_for<0, NumDim - 2, 1>{}([&](auto I) { + if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I)) + { + valid = false; + } + }); + + return valid && + (in_grid_desc.GetLength(Number{}) == + out_grid_desc.GetLength(Number{})) && + (in_grid_desc.GetLength(Number{}) == + out_grid_desc.GetLength(Number{})); + } + + template + __device__ static void Run(const InGridDesc in_grid_desc, + const OutGridDesc out_grid_desc, + const InDataType* p_in_global, + OutDataType* p_out_global, + void* __restrict__ p_shared, + const ElementwiseOperation elementwise_op, + const Block2TileMap& block_2_tile_map) + { + auto in_global_buf = make_dynamic_buffer( + p_in_global, in_grid_desc.GetElementSpaceSize()); + + auto out_global_buf = make_dynamic_buffer( + p_out_global, out_grid_desc.GetElementSpaceSize()); + + // each workgroup handles an [NPerBlock, HPerBlock, WPerBLock] slice-transpose problem + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock); + + const index_t h_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock); + + const index_t w_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock); + + // create [NPerBlock, HPerBlock, WPerBLock] shaped LDS buffer + constexpr auto in_block_desc_nperblock_hperblock_wperblock = + GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock(); + + auto in_block_buf = make_dynamic_buffer( + static_cast(p_shared), + in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize()); + + using BlockSliceLengths = Sequence; + using InBlockTransferAccessOrder = Sequence<0, 1, 2>; + + constexpr index_t SrcVectorDimAfterMerge = + SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3); + constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge; + + using ck::tensor_operation::element_wise::PassThrough; + + // merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x + // ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)] + const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc); + + // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS + auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + ElementwiseOperation, + PassThrough, + InMemoryDataOperationEnum::Set, + BlockSliceLengths, + InBlockTransferThreadClusterLengths, + InBlockTransferThreadClusterArrangeOrder, + InDataType, + InDataType, + decltype(in_grid_desc_n_h_w), + decltype(in_block_desc_nperblock_hperblock_wperblock), + InBlockTransferAccessOrder, + InBlockTransferAccessOrder, + SrcVectorDimAfterMerge, + 2, + SrcScalarPerVector, + 1, + 1, + 1, + true, + true>(in_grid_desc_n_h_w, + make_multi_index( + n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid), + PassThrough{}, + in_block_desc_nperblock_hperblock_wperblock, + make_multi_index(0, 0, 0), + PassThrough{}); + + // merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x + // ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)] + const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc); + + // create transposed view of output tensor + const auto out_grid_desc_n_h_w = transform_tensor_descriptor( + out_grid_desc_n_w_h, + make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)), + make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)), + make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{})); + + // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory + auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + ElementwiseOperation, + PassThrough, + InMemoryDataOperationEnum::Set, + BlockSliceLengths, + InBlockTransferThreadClusterLengths, + InBlockTransferThreadClusterArrangeOrder, + InDataType, + OutDataType, + decltype(in_block_desc_nperblock_hperblock_wperblock), + decltype(out_grid_desc_n_h_w), + InBlockTransferAccessOrder, + InBlockTransferAccessOrder, + 2, + DstVectorDimAfterMerge, + 1, + DstScalarPerVector, + 1, + 1, + true, + true>(in_block_desc_nperblock_hperblock_wperblock, + make_multi_index(0, 0, 0), + PassThrough{}, + out_grid_desc_n_h_w, + make_multi_index( + n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid), + elementwise_op); + + in_global_load.Run(in_grid_desc_n_h_w, + in_global_buf, + in_block_desc_nperblock_hperblock_wperblock, + in_block_buf, + I0); + + out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock, + in_block_buf, + out_grid_desc_n_h_w, + out_global_buf, + I0); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp new file mode 100644 index 0000000000000000000000000000000000000000..901e7aee98a7b6ab1c37b7dd4dc36f796632fef6 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffer_desc, + DataType* const __restrict__ p_global, + DataType value) + +{ + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + constexpr auto I0 = Number<0>{}; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const index_t thread_global_id = block_global_id * BlockSize + thread_local_id; + + StaticBuffer value_buf; + + value_buf(I0) = value; + + constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); + + auto global_buf = make_dynamic_buffer( + p_global, grid_1d_buffer_desc.GetElementSpaceSize()); + + if(thread_global_id < grid_1d_buffer_desc.GetElementSize()) + { + auto threadwise_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + grid_1d_buffer_desc, make_multi_index(thread_global_id), PassThroughOp{}); + + threadwise_store.Run( + val_buff_desc, make_tuple(I0), value_buf, grid_1d_buffer_desc, global_buf); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp new file mode 100644 index 0000000000000000000000000000000000000000..88c7b6acfeb9c43ec0e8b5a5ab557f834430fb40 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void +kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple, + DataTypePointerTuple p_global_tuple, + DataTypeTuple value_tuple) + +{ + static_assert(NumBuffer == DataTypePointerTuple::Size() && NumBuffer == DataTypeTuple::Size(), + "The tuple size should be same as NumBuffer!"); + + static_for<0, NumBuffer, 1>{}([&](auto iB) { + using DataTypePointer = remove_cvref_t; + using DataTypeFromPointer = remove_pointer_t; + using DataType = remove_cvref_t; + + static_assert(is_same::value, + "Types in tuples does not match!"); + }); + + constexpr auto I0 = Number<0>{}; + + const index_t thread_global_id = get_thread_global_1d_id(); + + auto value_buf_tuple = generate_tuple( + [&](auto iB) { + using DataType = remove_cvref_t; + + return StaticBuffer{}; + }, + Number{}); + + static_for<0, NumBuffer, 1>{}([&](auto iB) { + static_for<0, 1, 1>{}([&](auto J) { value_buf_tuple(iB)(J) = value_tuple[iB]; }); + }); + + auto global_buf_tuple = generate_tuple( + [&](auto iB) { + return make_dynamic_buffer( + p_global_tuple(iB), grid_1d_buffer_desc_tuple[iB].GetElementSpaceSize()); + }, + Number{}); + + constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); + + static_for<0, NumBuffer, 1>{}([&](auto iB) { + using DataType = remove_cvref_t; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + auto threadwise_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + grid_1d_buffer_desc_tuple[iB], make_multi_index(thread_global_id), PassThroughOp{}); + + threadwise_store.Run(val_buff_desc, + make_tuple(I0), + value_buf_tuple(iB), + grid_1d_buffer_desc_tuple[iB], + global_buf_tuple(iB)); + }); +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0344e68305b1a951cfc41a53666090bd1e0113b6 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, + const GridDesc_M_K out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) +{ + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_k, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); +}; + +template +struct GridwiseSoftmax_mk_to_mk +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (KThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k, + const GridDesc_M_K& out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer + out_thread_buf; + + StaticBuffer max_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + max_value_buf(I) = reduce::Max::template GetIdentityValue(); + }); + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // Normally, 0 as invalid element value is adequate since 0 makes no contribution to + // accumulated result. However, in stable softmax, all values 0s or not are subtracted by + // another value_max. As numbers become non-zero, effectively it allows invalid values to + // slip through and contribute to the accumulated result. + // + // The trick here is leveraging the fact that many math functions (add, sub, exp, ...) + // propagate NaNs when operands have NaNs involved. By initialiing invalid element value + // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still + // be identified as an invalid value. We can then discard the invalid values which + // originally failed the bound check during accumulation. This allows to ignore values that + // failed bound check even after multiple math manipulations. + // + // NOTE: reset coordinate after every step because the same threadwise copy will sweep + // through global memory 3 times back and forth + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2( + out_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3( + out_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + constexpr auto in_thread_copy_fwd_step = + make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); + constexpr auto in_thread_copy_bwd_step = + make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); + + /// + /// max(x) + /// + using BlockwiseMaxReduce = PartitionedBlockwiseReduction< + AccDataType, + BlockSize, + ThreadClusterLengths_M_K, + ThreadClusterArrangeOrder, + reduce::Max, + false, // param ignored + detail::AccumulateWithNanIgnore>; + + using ThreadwiseMaxReduce = + ThreadwiseReduction>; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize()); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); + block_sync_lds(); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + /// + /// sum(exp(x - max(x))) + /// + using BlockwiseSumReduce = PartitionedBlockwiseReduction< + AccDataType, + BlockSize, + ThreadClusterLengths_M_K, + ThreadClusterArrangeOrder, + reduce::Add, + false, // ignored + detail::AccumulateWithNanIgnore>; + + using ThreadwiseSumReduce = + ThreadwiseReduction>; + + reducedTiles = 0; + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + + // do element-wise pre-reduction operation + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); + }); + }); + + ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + block_sync_lds(); // wait for reading being complete before writing to LDS + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I)); + block_sync_lds(); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + /// + /// softmax + /// + reducedTiles = 0; + if(float_equal_zero{}(beta)) + { + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + StaticBuffer + in_prior_dst_buf; + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + threadwise_dst_load.Run(out_grid_desc_m_k, + out_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_prior_dst_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM) + + beta * in_prior_dst_buf(Number{}); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3de6aa08c45b9fadbdf88d1554d40824a5f555ce --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" + +namespace ck { + +template +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + __global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out, + const EmbType* p_emb_a, + const EmbType* p_emb_b, + const EmbType* p_emb_c, + const IndexType* p_index_a, + const IndexType* p_index_b, + const IndexType* p_index_c, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const OutGridDesc out_grid_desc, + const AccDataType epsilon) +{ + GridwiseSparseEmbedding::Run(p_out, + p_emb_a, + p_emb_b, + p_emb_c, + p_index_a, + p_index_b, + p_index_c, + p_gamma, + p_beta, + out_grid_desc, + epsilon); +} + +template +struct GridwiseSparseEmbedding3ForwardLayernorm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t WaveSize = 64; + + static_assert(BlockSize == RowClusterSize * DimClusterSize, + "Invalid cluster distribution within block"); + static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise"); + + static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, ""); + static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, ""); + + static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize); + static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize); + + static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), ""); + static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks; + static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks; + + using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}))); + + using ThreadwiseWolfordDescReduce = decltype( + make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using ThreadClusterLength = Sequence; + + using BlockwiseWelford = + BlockwiseWelford>; + + __device__ static void Run(OutType* p_out, + const EmbType* p_emb_a, + const EmbType* p_emb_b, + const EmbType* p_emb_c, + const IndexType* p_index_a, + const IndexType* p_index_b, + const IndexType* p_index_c, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const OutGridDesc, + const AccDataType epsilon) + { + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + // const auto index_length = out_grid_desc.GetLength(I0); + // const auto emb_dim = out_grid_desc.GetLength(I1); + + constexpr auto thread_cluster_desc = + make_cluster_descriptor(Sequence{}, Sequence<0, 1>{}); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_dim_cluster_id = thread_cluster_idx[I0]; + const auto thread_row_cluster_id = thread_cluster_idx[I1]; + + const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize); + + const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize; + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize; + + constexpr auto thread_buf_size = + DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize; + constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed( + make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize)); + constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize; + constexpr auto mean_var_buf_desc = + make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize)); + constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize; + constexpr auto gamma_beta_buf_desc = + make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize)); + + StaticBuffer in_thread_buf_a; + StaticBuffer in_thread_buf_b; + StaticBuffer in_thread_buf_c; + + StaticBuffer index_buf_a; + StaticBuffer index_buf_b; + StaticBuffer index_buf_c; + + StaticBuffer acc_thread_buf; + + StaticBuffer + gamma_thread_buf; + StaticBuffer + beta_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + vector_type_maker_t emb_vector_a; + vector_type_maker_t emb_vector_b; + vector_type_maker_t emb_vector_c; + + using src_vector_t = typename decltype(emb_vector_a)::type; + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_; + IndexType index_a = index_buf_a[Number{}]; + IndexType index_b = index_buf_b[Number{}]; + IndexType index_c = index_buf_c[Number{}]; + + auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(EmbType) * RowVectorSize; + + int32x4_t emb_res_a = + make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock); + int32x4_t emb_res_b = + make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock); + int32x4_t emb_res_c = + make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock); + emb_vector_a.template AsType()(I0) = + amd_buffer_load_impl(emb_res_a, thread_offset, 0); + emb_vector_b.template AsType()(I0) = + amd_buffer_load_impl(emb_res_b, thread_offset, 0); + emb_vector_c.template AsType()(I0) = + amd_buffer_load_impl(emb_res_c, thread_offset, 0); + + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + in_thread_buf_a(Number{}) = + emb_vector_a.template AsType()[i_row_vec_]; + in_thread_buf_b(Number{}) = + emb_vector_b.template AsType()[i_row_vec_]; + in_thread_buf_c(Number{}) = + emb_vector_c.template AsType()[i_row_vec_]; + }); + }); + }; + + auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + AccDataType va = + ck::type_convert(in_thread_buf_a(Number{})); + AccDataType vb = + ck::type_convert(in_thread_buf_b(Number{})); + AccDataType vc = + ck::type_convert(in_thread_buf_c(Number{})); + + acc_thread_buf(Number{}) += va + vb + vc; + }); + }); + }; + + auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + + threadwise_welford.cur_count_++; + threadwise_welford.Update(mean_thread_buf(Number{}), + var_thread_buf(Number{}), + acc_thread_buf(Number{})); + }); + }); + }; + + auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) { + int32x4_t out_res = + make_wave_buffer_resource_with_default_range(p_out + index_start * RowPerBlock); + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + vector_type_maker_t out_vector; + using dst_vector_t = typename decltype(out_vector)::type; + + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto gamma_beta_offset = + gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_)); + + auto acc_val = acc_thread_buf[Number{}]; + acc_val = (acc_val - mean_thread_buf(Number{})) / + sqrt(var_thread_buf(Number{}) + epsilon); + acc_val = acc_val * gamma_thread_buf[Number{}] + + beta_thread_buf[Number{}]; + + out_vector.template AsType()(Number{}) = + type_convert(acc_val); + }); + + index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(OutType) * RowVectorSize; + + amd_buffer_store_impl( + out_vector.template AsType()[Number<0>{}], + out_res, + thread_offset, + 0); + }); + }; + + // first load index + ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { + // prefer use s_load + index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value]; + index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value]; + index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value]; + }); + + // load gamma/beta + static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) { + vector_type_maker_t gamma_vector; + vector_type_maker_t beta_vector; + + index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(GammaDataType) * RowVectorSize; + index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(BetaDataType) * RowVectorSize; + + int32x4_t gamma_res = make_wave_buffer_resource_with_default_range(p_gamma); + int32x4_t beta_res = make_wave_buffer_resource_with_default_range(p_beta); + + gamma_vector.template AsType()(I0) = + amd_buffer_load_impl( + gamma_res, thread_offset_gamma, 0); + beta_vector.template AsType()(I0) = + amd_buffer_load_impl(beta_res, thread_offset_beta, 0); + + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto offset = + gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_)); + gamma_thread_buf(Number{}) = type_convert( + gamma_vector.template AsType()[Number{}]); + beta_thread_buf(Number{}) = type_convert( + beta_vector.template AsType()[Number{}]); + }); + }); + + static_for<0, thread_buf_size, 1>{}( + [&](auto I) { acc_thread_buf(I) = type_convert(0.0f); }); + + static_for<0, mean_var_buf_size, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) { + load_current_sub_row(i_dim_sub, Number<0>{}); + static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) { + load_current_sub_row(i_dim_sub, Number<1>{} + i_row); + accumulate_current_sub_row(i_dim_sub, i_row); + threadwise_welford_sub_row(i_dim_sub, i_row); + }); + accumulate_current_sub_row(i_dim_sub, Number{}); + threadwise_welford_sub_row(i_dim_sub, Number{}); + + // blockwise welford + static_for<0, mean_var_buf_size, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_); + }); + + // store + static_for<0, RowSubBlocks, 1>{}( + [&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); }); + }); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..188c62d93b0d60317ac139115b79f496612d2ea6 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +// Assume +// 1) SrcDesc is known at compile-time +// 2) DstDesc is known at compile-time +// 3) SrcBuffer is static buffer +// 4) DstBuffer is static buffer +template > +struct ThreadwiseReduction +{ + static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + using Op = OpReduce; + + template + __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Accumulation::Calculate(dst_buf(Number{}), src_buf[Number{}]); + }); + }); + }; +}; + +// Assume +// 1) SrcDesc is known at compile-time +// 2) DstDesc is known at compile-time +// 3) SrcBuffer is static buffer +// 4) DstBuffer is static buffer +template < + typename AccDataType, + typename IndexDataType, + typename SrcThreadDesc_M_K, + typename DstThreadDesc_M, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> +struct ThreadwiseReductionWithIndex +{ + static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + template + __device__ static void Reduce(const SrcValueBufferType& src_val_buf, + const SrcIndexBufferType& src_idx_buf, + DstValueBufferType& dst_val_buf, + DstIndexBufferType& dst_idx_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Accumulation::Calculate(dst_val_buf(Number{}), + src_val_buf[Number{}], + dst_idx_buf(Number{}), + src_idx_buf[Number{}]); + }); + }); + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..94cdfe01087bc4005d45f11992a6a8f94d42691c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1] +// Tensor element can be vectorized data +// Assume: +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 +{ + __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto TK = TKLengths{}[I0]; + constexpr auto TM0 = TMLengths{}[I0]; + constexpr auto TM1 = TMLengths{}[I1]; + constexpr auto TN0 = TNLengths{}[I0]; + constexpr auto TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK, 1>{}([&](auto tk) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk, tm0, tm1)); + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk, tn0, tn1)); + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1] +// Tensor element can be vectorized data +// Assume: +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +{ + __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr index_t TK0 = TKLengths{}[I0]; + constexpr index_t TK1 = TKLengths{}[I1]; + constexpr index_t TM0 = TMLengths{}[I0]; + constexpr index_t TM1 = TMLengths{}[I1]; + constexpr index_t TN0 = TNLengths{}[I0]; + constexpr index_t TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK0, 1>{}([&](auto tk0) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + vector_type a_vec; + vector_type b_vec; + + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); + + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); + + a_vec.template AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; + }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product( + a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e045e3b545a9897f5df9df1948123e86424d7c7b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP +#define CK_THREADWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +// C[M, N] += transpose(A[K, M]) * B[K, N] +// Element of matrix can be vectorized data +// Assume: +// 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at +// compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDlops_km_kn_mn_v3 +{ + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + + static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() && + BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && + CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0); + constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1); + constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2); + + constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); + constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + if constexpr((Ho % 2 == 0) && (Wo % 2 == 0)) + { + constexpr auto SubHW = 2; + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, Ho, SubHW>{}([&](auto h) { + static_for<0, Wo, SubHW>{}([&](auto w) { + static_for<0, E1, 1>{}([&](auto e1) { + static_for<0, E2, 1>{}([&](auto e2) { + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); + + constexpr index_t b0_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); + + constexpr index_t b1_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w + 1, e2)); + + constexpr index_t b2_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w, e2)); + + constexpr index_t b3_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2)); + + constexpr index_t c0_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + + make_tuple(k, 0, h, w)); + + constexpr index_t c1_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h, w + 1)); + + constexpr index_t c2_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w)); + + constexpr index_t c3_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w + 1)); + + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + }); + }); + }); + }); + }); + } + else + { + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, Ho, 1>{}([&](auto h) { + static_for<0, Wo, 1>{}([&](auto w) { + static_for<0, E1, 1>{}([&](auto e1) { + static_for<0, E2, 1>{}([&](auto e2) { + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); + + constexpr index_t b_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); + + constexpr index_t c_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + + make_tuple(k, 0, h, w)); + + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } + } +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a1197a1630c12c53afa9513f10fcfa907ba82d3 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. Desc is known at compile-time +// 2. Buffer is StaticBuffer +// 3. OriginIdx is known at compile-time +// 4. use #-step +template ::type = false> +struct ThreadwiseTensorSliceSet_v1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + template + __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const + { + static_assert(Desc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value, + "wrong! OriginIdx need to be known at compile-time"); + + // Desc is known at compile-time + constexpr auto desc = remove_cvref_t{}; + + // OriginIdx is known at compile-time + constexpr auto origin_idx = to_multi_index(OriginIdx{}); + + static_ford{}([&](auto access_idx) { + constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx); + + constexpr bool is_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); + + constexpr index_t offset = coord.GetOffset(); + + if constexpr(is_valid) + { + buf(Number{}) = initial_value; + } + }); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b0f453b025f00803091b200cc7201e47e1d4993e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -0,0 +1,1301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? ScalarPerVector : 1; + } +}; + +template +struct lambda_scalar_step_in_vector +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 1 : 0; + } +}; +} // namespace detail + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOrginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. DstSliceOrginIdx is not known at compile time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v1r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, + const Index& dst_slice_origin_idx, + const ElementwiseOperation& element_op) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), + element_op_{element_op} + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + typename vector_type_maker::type dst_vector; + using dst_vector_t = typename vector_type_maker::type::type; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve? + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v; + + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + + // apply type convert + dst_vector.template AsType()(i) = type_convert(v); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; // namespace ThreadwiseTensorSliceTransfer_v1r3 + +// Assume: +// 1. src: +// 1. SrcDesc is not known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_slice_origin_idx is not known at compile-time +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. dst_slice_origin_idx is known at compile-time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v2 +{ + static_assert((InvalidElementAsNaN && !std::is_integral::value) || + (!InvalidElementAsNaN), + "Filling invalid element as NaN is only for floating point types"); + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, + const Index& src_slice_origin_idx) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over tensor and copy + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + + i * src_scalar_step_in_vector); + + if constexpr(InvalidElementAsNaN) + { + dst_buf(Number{}) = + is_src_valid + ? type_convert(src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } + }); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + } + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; +}; // namespace ck + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_tmp_vector + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_tmp_vector to buffer_ + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); + + buffer_(Number{}) = src_tmp_vector.template AsType()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + vector_type_maker_t dst_tmp_vector; + + // copy data from buffer_ to dst_tmp_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); + + dst_tmp_vector.template AsType()(i) = + type_convert(buffer_[Number{}]); + }); + + using dst_vector_t = typename decltype(dst_tmp_vector)::type; + + // copy data from dst_tmp_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_tmp_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_step_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_step_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_tmp_vector + if constexpr(SrcBuffer::IsDynamicBuffer()) + { + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + } + else if constexpr(SrcBuffer::IsStaticBuffer()) + { + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_ref_to_origin_disp_idx + data_to_origin_disp_idx + + i * src_scalar_step_in_vector); + + // apply type convert + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); + } + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + __device__ void SetSrcCoord(const Index& src_ref_idx) + { + src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx); + } + + private: + SrcCoord src_ref_coord_; +}; + +// Do NOT involve any tensor coordinates with StaticBuffer +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic( + const ElementwiseOperation& element_op) + : element_op_{element_op} + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v; + + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + + // apply type convert + dst_buf(Number{}) = type_convert(v); + }); + }); + } + + ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bb28c194f4b7a2f371d2be00cb00fc4fe3c17d3f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -0,0 +1,794 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // apply SrcElementwiseOperation on src_vector_container + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + SrcData src_v; + + src_element_op_(src_v, src_vector_container.template AsType()[i]); + + src_vector_container.template AsType()(i) = src_v; + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // TODO type_convert is not used yet!!!!! + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + // TODO type_convert is not used yet!!!!! + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratch dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a73466efa443af294f2062ee5173cf548880c3d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp @@ -0,0 +1,886 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); + using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); + using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r3( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)), + dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, + const Dst0Desc& dst0_desc, + const Dst1Desc& dst1_desc, + const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx); + dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // apply SrcElementwiseOperation on src_vector_container + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + src_vector_container.template AsType()(i) = + src_element_op_(src_vector_container.template AsType()[i]); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_.template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // TODO type_convert is not used yet!!!!! + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + // TODO type_convert is not used yet!!!!! + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + const Dst0Desc& dst0_desc, + const Dst0Buffer& dst0_buf, + const Dst1Desc& dst1_desc, + const Dst1Buffer& dst1_buf) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, forward_step_idx); + }, + Number{}); + + // make forward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst0 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst0_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst0_desc, backward_step_idx); + }, + Number{}); + + // make backward steps: dst1 + // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same + // DstScalarPerVector + // TODO: fix this + const auto dst1_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst1_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + // apply DstElementwiseOperation on dst_vector_container + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + dst_vector_container.template AsType()(i) = + dst_element_op_(dst_vector_container.template AsType()[i]); + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + // TODO: BUG: should start at 1 + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Dst0Desc dst0_desc, + const Dst1Desc dst1_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step); + move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + StaticTensorTupleOfVectorBuffer + src_thread_scratch_; + + StaticTensorTupleOfVectorBuffer + dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e8a23930bbb91677ee18bab216af6c45de72e4c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +namespace ck { +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); + }); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f13da341f9b2d36d992b1cd1835b43c89dcabd1f --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -0,0 +1,614 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v5r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) + { + // TODO: fix this + static_assert(is_same::value, + "wrong! current implementation assume SrcData and DstData are same type"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 && + SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0, + "wrong!"); + }); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector to buffer_ + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx); + + buffer_(Number{}) = + src_vector.template AsType()[Number{}]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // tensor descriptor for dst_vector + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(dst_vector_tensor_lengths, + DstVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + DstVectorTensorContiguousDimOrder{}); + + constexpr auto dst_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(dst_vector_tensor_lengths), + sequence_to_tuple_of_number(dst_vector_tensor_strides)); + + // dst access order and lengths + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + vector_type_maker_t dst_vector; + + // copy data from buffer_ to dst_vector (also cast from SrcData to DstData) + static_ford{}([&](auto dst_vector_idx_) { + constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx); + + constexpr index_t dst_vector_offset = + dst_vector_desc.CalculateOffset(dst_vector_idx); + + dst_vector.template AsType()(Number{}) = + type_convert(buffer_[Number{}]); + }); + + using dst_vector_t = typename decltype(dst_vector)::type; + + // copy data from dst_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_step_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_step_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c91cd9ca8f86df37fee6dfdf597be1cefe2d683 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over space-filling curve + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..68bc2726f4be6fa1b08fe9213e0db9440526f975 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0f5fb88b04540dbf355fb0f02998303b7680c193 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + src2_coord_(make_tensor_coordinate(src2_desc, src2_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetSrc2SliceOrigin(const Src2Desc& src2_desc, + const Index& src2_slice_origin_idx) + { + src2_coord_ = make_tensor_coordinate(src2_desc, src2_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using src2_vector_type = vector_type_maker_t; + using src2_vector_t = typename src2_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + const bool is_src2_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src2_desc, src2_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto src2_vector_container = src2_vector_type{ + src2_buf.template Get(src2_coord_.GetOffset(), is_src2_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i], + src2_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + move_tensor_coordinate( + src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(Src2ResetCoordinateAfterRun) + { + const auto src2_reset_step = + make_tensor_coordinate_step(src2_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src2_desc, src2_coord_, src2_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, + const Index& src2_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src2ResetCoordinateAfterRun + ? src2_slice_origin_step_idx + : src2_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx); + + move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + Src2Coord src2_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2eb1b0ee90a446ee3b14354aff018275cf7809f4 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags> // Sequence +struct ThreadwiseTensorSliceTransfer_v7 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = + SpaceFillingCurve>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + auto generate_vectors = [&](auto data_types) { + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + }; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(SrcDatas{}); + auto dst_vectors = generate_vectors(DstDatas{}); + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), + is_src_valid); + }); + + // apply pointwise function + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12ba2c5381311ab45ffad4be793651d176f5fc59 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math_v2.hpp" + +namespace ck { + +// Assume +// 1) XDesc is known at compile-time +// 2) MeanVarDesc is known at compile-time +// 3) XBuffer is static buffer +// 4) MeanBuffer is static buffer +// 5) VarBuffer is static buffer +template +struct ThreadwiseWelford +{ + static constexpr auto x_thread_desc_m_k = XThreadDesc_M_K{}; + static constexpr auto mean_var_thread_desc_m = MeanVarThreadDesc_M{}; + + static constexpr auto thread_x_length_m = x_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto thread_x_length_k = x_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto thread_mean_var_length_m = mean_var_thread_desc_m.GetLength(Number<0>{}); + + static_assert(thread_x_length_m == thread_mean_var_length_m, + "lengths of source and mean/var buffer must match!"); + + __device__ constexpr ThreadwiseWelford() : cur_count_(0), max_count_(0) {} + + __device__ inline void Update(T& mean, T& var, T x) + { + using ck::math::isnan; + + if(isnan(x)) + { + mean = x; + var = x; + } + else + { + T delta = x - mean; + mean += delta / cur_count_; + T delta2 = x - mean; + var += delta * delta2; + } + } + + template + __device__ void + Run(const XBufferType& x_buf_m_k, MeanBufferType& mean_buf_m, VarBufferType& var_buf_m) + { + // FIXME - Better naming for var_buf_m + + static_for<0, thread_x_length_k, 1>{}([&](auto iK) { + if(cur_count_ < max_count_) + { + ++cur_count_; + + static_for<0, thread_x_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = + mean_var_thread_desc_m.CalculateOffset(make_tuple(iM)); + + constexpr auto in_offset = + x_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + Update(mean_buf_m(Number{}), + var_buf_m(Number{}), + x_buf_m_k[Number{}]); + }); + } + }); + }; + + int cur_count_; + int max_count_; +}; + +template +struct ThreadwiseWelfordMerge +{ + static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + __device__ static void + Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b) + { + int count = count_a + count_b; + T count_b_over_count = count == 0 ? type_convert(0) : type_convert(count_b) / count; + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a * count_b_over_count; + count_a = count; + } + + template + __device__ static void Run(const SrcMeanBufferType& src_mean_buf, + const SrcVarBufferType& src_var_buf, + const SrcCountBufferType& src_count_buf, + DstMeanBufferType& dst_mean_buf, + DstVarBufferType& dst_var_buf, + DstCountBufferType& dst_count_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Merge(dst_mean_buf(iM), + dst_var_buf(iM), + dst_count_buf(iM), + src_mean_buf[Number{}], + src_var_buf[Number{}], + src_count_buf[Number{}]); + }); + + if constexpr(GetActualVariance) + { + dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM]; + }; + }); + }; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d53f0d81620141ff6dfc6cb965209709341aaf8 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -0,0 +1,851 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_xdlops.hpp" + +namespace ck { + +enum struct MfmaInstr +{ + mfma_f32_32x32x1xf32 = 0, + mfma_f32_16x16x1xf32, + mfma_f32_4x4x1xf32, + mfma_f32_32x32x2xf32, + mfma_f32_16x16x4xf32, + mfma_f32_32x32x4f16, + mfma_f32_16x16x4f16, + mfma_f32_4x4x4f16, + mfma_f32_32x32x8f16, + mfma_f32_16x16x16f16, + mfma_f32_32x32x8bf16_1k, + mfma_f32_16x16x16bf16_1k, + mfma_f32_32x32x4bf16, + mfma_f32_16x16x8bf16, + mfma_i32_32x32x8i8, + mfma_i32_16x16x16i8, + mfma_f64_16x16x4f64 +}; + +template +struct mfma_type; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); + } +}; + +// treat 4x4x1 as a single-blk 4x64 mfma +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x8bf16_1k::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16bf16_1k::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4bf16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x8bf16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x8i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x16i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 1; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f64_16x16x4f64::Run(a, b, reg_c); + } +}; + +template +struct MfmaSelector +{ + template + static constexpr auto GetMfma(); + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f64_16x16x4f64; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x2xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x8f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x16f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + static constexpr auto GetMfma() + { +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif + } + + template <> + static constexpr auto GetMfma() + { +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x8i8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x16i8; + } + + static constexpr auto selected_mfma = mfma_type()>{}; + + __host__ __device__ constexpr MfmaSelector() + { + static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == + selected_mfma.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks == + selected_mfma.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks || + selected_mfma.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size == + selected_mfma.m_per_blk * selected_mfma.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_mfma.is_k_reduction || + (selected_mfma.num_input_blks == selected_mfma.num_output_blks), + "is_k_reduction wrong!"); + } + + static constexpr bool IsABroadcast() + { + static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); + return true; + } + + static constexpr index_t GetKPerXdlops() + { + return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) * + selected_mfma.k_per_blk; + } + + static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; } +}; + +template +struct XdlopsGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks); + } + + __host__ __device__ constexpr XdlopsGemm() + { + static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 || + NPerXdlops == 64, + "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 || + MPerXdlops == 64, + "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + // transposed XDL output supporting C' = B' * A' + // M2_N2 -> M2_N2_N3_N4 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{})); + } + + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / mfma_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "base base_type must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + mfma_instr.template run( + p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + mfma_instr.template run( + p_b_wave[k], p_a_wave[k], p_c_thread); + } + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size; + + return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; + } + + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; + } + + static constexpr auto mfma = MfmaSelector{}; + + static constexpr auto mfma_instr = mfma.selected_mfma; + + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5fc11d9158ab9a1261c1285c2cd9f5208d7afbf3 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] +template +static auto MakeGridDescriptorPair(const std::vector& gs_ms_ns_lengths_vec, + const std::vector& gs_ms_ns_strides_vec) +{ + if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) + { + throw std::runtime_error("wrong! dimension must match input lengths"); + } + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto gs_ms_ns_lengths = + to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto gs_ms_ns_strides = + to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds); + + if constexpr(TensorSpec == device::TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } + else + { + // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides); + + // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + // Note: This does not require padding as it only provides G offset calculation. Technically + // descriptor for only G is needed. Here we opt for backward compatibility purpose to return + // G_M_N + const auto grid_desc_g_mraw_nraw = + transform_tensor_descriptor(grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto c_ms_ns_lengths = to_tuple( + gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto c_ms_ns_strides = to_tuple( + gs_ms_ns_strides_vec, Number{}, Number{}); + + // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + const auto grid_desc_mraw_nraw = transform_tensor_descriptor( + grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds - Number{}, nDimIds - Number{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } +} + +template + typename PerBlock_M_N_K_O, // Sequence<> + device::GemmSpecialization GemmSpec, + device::TensorSpecialization ASpec, + device::TensorSpecialization B0Spec, + device::TensorSpecialization B1Spec, + device::TensorSpecialization CSpec> +struct TransformBatchedContractionContractionToBatchedGemmGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0); + static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1); + static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2); + static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3); + static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4); + + static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0); + static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1); + static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2); + static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3); + + static constexpr auto matrix_padder = + device::GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, OPerBlock}; + + // + // A + // + static auto MakeAGridDescriptorPair(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return MakeGridDescriptorPair(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec); + } + + // TODO: rename to G_MRaw_KRaw + static auto MakeAGridDescriptor_G_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; + } + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return matrix_padder.PadADescriptor_M_K( + MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // B (alias of B0) + // + static auto MakeB0GridDescriptorPair(const std::vector& b0_gs_ns_ks_lengths_vec, + const std::vector& b0_gs_ns_ks_strides_vec) + { + return MakeGridDescriptorPair(b0_gs_ns_ks_lengths_vec, + b0_gs_ns_ks_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + static auto MakeB0GridDescriptor_G_N_K(const std::vector& b0_gs_ns_ks_lengths_vec, + const std::vector& b0_gs_ns_ks_strides_vec) + { + return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; + } + static auto MakeB0GridDescriptor_N_K(const std::vector& b0_gs_ns_ks_lengths_vec, + const std::vector& b0_gs_ns_ks_strides_vec) + { + // alias of matrix_padder.PadB0Descriptor_N_K + return matrix_padder.PadBDescriptor_N_K( + MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // B1 + // + static auto MakeB1GridDescriptorPair(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + return MakeGridDescriptorPair(b1_gs_os_ns_lengths_vec, + b1_gs_os_ns_strides_vec); + } + + // TODO: rename to G_NRaw_KRaw + static auto MakeB1GridDescriptor_G_N_K(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; + } + static auto MakeB1GridDescriptor_N_K(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + // alias of matrix_padder.PadB1Descriptor_O_N + return matrix_padder.PadB1Descriptor_N_K( + MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1) + { + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // C + // + static auto MakeCGridDescriptorPair(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return MakeGridDescriptorPair(c_gs_ms_os_lengths_vec, + c_gs_ms_os_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + static auto MakeCGridDescriptor_G_M_N(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; + } + static auto MakeCGridDescriptor_M_N(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return matrix_padder.PadCDescriptor_M_N( + MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13d0a28cfe5b2db14a83a4b81e0bb379aabbe137 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -0,0 +1,583 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { +namespace tensor_operation { + +template < + index_t NDimSpatial, + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, + index_t AK1, + index_t BK1, + index_t GemmMPerBlock, + index_t GemmNPerBlock, + bool DoPadGemmM, + bool DoPadGemmN> +struct TransformConvBwdDataToGemm_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template , + bool>::type = false> + static auto MakeADescriptor_AK0_M_AK1( + const std::array& out_g_n_k_wos_lengths, + const std::array& /* out_g_n_k_wos_strides */, + const std::array& wei_g_k_c_xs_lengths, + const std::array& /* wei_g_k_c_xs_strides */, + const std::array& in_g_n_c_wis_lengths, + const std::array& /* in_g_n_c_wis_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& /* input_right_pads */, + const std::array& tildes) + { + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t N = in_g_n_c_wis_lengths[1]; + const index_t K = wei_g_k_c_xs_lengths[1]; + + const index_t Hi = in_g_n_c_wis_lengths[3]; + const index_t Wi = in_g_n_c_wis_lengths[4]; + + const index_t Ho = out_g_n_k_wos_lengths[3]; + const index_t Wo = out_g_n_k_wos_lengths[4]; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t AK0 = K / AK1; + + // assume packed + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(AK0, AK1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmak0_gemmmraw_gemmak1_grid_desc, + make_tuple(AK0, GemmMPerBlock, AK1), + Sequence{}); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(AK0, AK1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(AK1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmak0_gemmmraw_gemmak1_grid_desc, + make_tuple(AK0, GemmMPerBlock, AK1), + Sequence{}); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + } + + template , + bool>::type = false> + static auto MakeBDescriptor_BK0_N_BK1( + const std::array& out_g_n_k_wos_lengths, + const std::array& /* out_g_n_k_wos_strides */, + const std::array& wei_g_k_c_xs_lengths, + const std::array& /* wei_g_k_c_xs_strides */, + const std::array& in_g_n_c_wis_lengths, + const std::array& /* in_g_n_c_wis_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& /* input_left_pads */, + const std::array& /* input_right_pads */, + const std::array& tildes) + { + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t N = in_g_n_c_wis_lengths[1]; + const index_t K = wei_g_k_c_xs_lengths[1]; + const index_t C = wei_g_k_c_xs_lengths[2]; + + const index_t Ho = out_g_n_k_wos_lengths[3]; + const index_t Wo = out_g_n_k_wos_lengths[4]; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t BK0 = K / BK1; + + // assume packed + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + // B: weight tensor + const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1)); + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, + make_tuple(BK0, GemmNPerBlock, BK1), + Sequence{}); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)), + make_pass_through_transform(C), + make_pass_through_transform(BK1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, + make_tuple( + wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), GemmNPerBlock, BK1), + Sequence{}); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + static auto + MakeCDescriptor_M_N(const std::array& out_g_n_k_wos_lengths, + const std::array& /* out_g_n_k_wos_strides */, + const std::array& wei_g_k_c_xs_lengths, + const std::array& /* wei_g_k_c_xs_strides */, + const std::array& in_g_n_c_wis_lengths, + const std::array& in_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const std::array& tildes) + { + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t N = in_g_n_c_wis_lengths[1]; + const index_t C = wei_g_k_c_xs_lengths[2]; + + const index_t Hi = in_g_n_c_wis_lengths[3]; + const index_t Wi = in_g_n_c_wis_lengths[4]; + + const index_t Ho = out_g_n_k_wos_lengths[3]; + const index_t Wo = out_g_n_k_wos_lengths[4]; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + // assume strided + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(in_g_n_c_wis_strides[1], + in_g_n_c_wis_strides[3], + in_g_n_c_wis_strides[4], + in_g_n_c_wis_strides[2])); + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + } + + // for input bias + template || + is_same_v), + bool>::type = false> + static auto + MakeCDescriptor_M_N(const std::array& out_g_n_k_wos_lengths, + const std::array& /* out_g_n_k_wos_strides */, + const std::array& wei_g_k_c_xs_lengths, + const std::array& /* wei_g_k_c_xs_strides */, + const std::array& in_g_n_c_wis_lengths, + const std::array& /* in_g_n_c_wis_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& /* input_right_pads */, + const std::array& /* tildes */) + { + const index_t N = in_g_n_c_wis_lengths[1]; + const index_t C = wei_g_k_c_xs_lengths[2]; + + const index_t Hi = in_g_n_c_wis_lengths[3]; + const index_t Wi = in_g_n_c_wis_lengths[4]; + + const index_t Ho = out_g_n_k_wos_lengths[3]; + const index_t Wo = out_g_n_k_wos_lengths[4]; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + const auto in_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1)); + + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // bias tensor + const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1)); + + const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1b5e64b66cf292fbcf4979d95aa4a5b1fa3e3a8b --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -0,0 +1,880 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +template +struct TransformConvFwdToGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template , + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& /* a_g_n_c_wis_strides */, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Wi = a_g_n_c_wis_lengths[3]; + + const index_t Wo = c_g_n_k_wos_lengths[3]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(NWo, C)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t X = b_g_k_c_xs_lengths[3]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template , + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& /* a_g_n_c_wis_strides */, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Hi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[4]; + + const index_t Ho = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const auto in_n_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template , + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& /* a_g_n_c_wis_strides */, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Di = a_g_n_c_wis_lengths[3]; + const index_t Hi = a_g_n_c_wis_lengths[4]; + const index_t Wi = a_g_n_c_wis_lengths[5]; + + const index_t Do = c_g_n_k_wos_lengths[3]; + const index_t Ho = c_g_n_k_wos_lengths[4]; + const index_t Wo = c_g_n_k_wos_lengths[5]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_di_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Z = b_g_k_c_xs_lengths[3]; + const index_t Y = b_g_k_c_xs_lengths[4]; + const index_t X = b_g_k_c_xs_lengths[5]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const auto in_n_di_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as + // properties + template || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Wi = a_g_n_c_wis_lengths[3]; + + const index_t Wo = c_g_n_k_wos_lengths[3]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t X = b_g_k_c_xs_lengths[3]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Hi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[4]; + + const index_t Ho = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Di = a_g_n_c_wis_lengths[3]; + const index_t Hi = a_g_n_c_wis_lengths[4]; + const index_t Wi = a_g_n_c_wis_lengths[5]; + + const index_t Do = c_g_n_k_wos_lengths[3]; + const index_t Ho = c_g_n_k_wos_lengths[4]; + const index_t Wo = c_g_n_k_wos_lengths[5]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Z = b_g_k_c_xs_lengths[3]; + const index_t Y = b_g_k_c_xs_lengths[4]; + const index_t X = b_g_k_c_xs_lengths[5]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto + MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = ck::accumulate_n( + b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto wei_gemmn_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + + return wei_gemmn_gemmk_desc; + } + + template < + typename BLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = ck::accumulate_n( + b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const index_t KStride = b_g_k_c_xs_strides[1]; + const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( + make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); + + const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( + wei_k_yx_c_desc, + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return wei_gemmn_gemmk_desc; + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto + MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); + + return out_gemmm_gemmn_desc; + } + + template < + typename CLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const auto KStride = I1; + const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; + + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); + + return out_gemmm_gemmn_desc; + } + + // for output bias + template || + is_same_v, + bool>::type = false> + static auto + MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1)); + + return out_gemmm_gemmn_desc; + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/amd_address_space.hpp b/3rdparty/composable_kernel/include/ck/utility/amd_address_space.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9f1525914cdee16ee6cbed51297f0008d46f6085 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/amd_address_space.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "c_style_pointer_cast.hpp" + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +namespace ck { + +enum struct AddressSpaceEnum +{ + Generic, + Global, + Lds, + Sgpr, + Vgpr, +}; + +template +__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) +{ + // cast a pointer in "Constant" address space (4) to "Generic" address space (0) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +template +__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) +{ + // cast a pointer in "Generic" address space (0) to "Constant" address space (4) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/amd_buffer_addressing.hpp b/3rdparty/composable_kernel/include/ck/utility/amd_buffer_addressing.hpp new file mode 100644 index 0000000000000000000000000000000000000000..79295356df8715b3e95dbe3b2238b802e6466f50 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/amd_buffer_addressing.hpp @@ -0,0 +1,1177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "data_type.hpp" + +namespace ck { + +template +union BufferResource +{ + __device__ constexpr BufferResource() : content{} {} + + // 128 bit SGPRs to supply buffer resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + StaticallyIndexedArray address; + StaticallyIndexedArray range; + StaticallyIndexedArray config; +}; + +template +__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T); + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +template +__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +// buffer load i8 +__device__ int8_t +llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); + +__device__ int8x2_t +llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); + +__device__ int8x4_t +llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); + +// buffer load i16 +__device__ bhalf_t +llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); + +__device__ bhalf2_t +llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); + +__device__ bhalf4_t +llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); + +// buffer load i32 +__device__ int32_t +llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + +__device__ int32x2_t +llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); + +__device__ int32x4_t +llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +// buffer load fp16 +__device__ half_t +llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +__device__ half2_t +llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ half4_t +llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); + +// buffer load fp32 +__device__ float +llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ float2_t +llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + +__device__ float4_t +llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + +// buffer store i8 +__device__ void +llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); + +// buffer store i16 +__device__ void +llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + +// buffer store i32 +__device__ void +llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +// buffer store fp16 +__device__ void +llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); + +// buffer store fp32 +__device__ void +llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); + +// buffer atomic-add fp16 +__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); + +// buffer atomic-add i32 +__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// buffer atomic-add fp32 +__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); + +// buffer atomic-add fp32 +__device__ double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); + +template +__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + // use fp32 load to mimic fp64 load + if constexpr(N == 1) + { + const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); + } + else if constexpr(N == 2) + { + const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); + } + else if constexpr(N == 4) + { + const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + const float4_t f32_1 = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); + vector_type tmp; + + tmp.AsType()(Number<0>{}) = bit_cast(f32_0); + tmp.AsType()(Number<1>{}) = bit_cast(f32_1); + + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); + + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + // use fp32 load to mimic fp16 load + float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + 0); + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i8( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); +#endif + } + else if constexpr(N == 4) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); +#endif + } + else if constexpr(N == 8) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); +#endif + } + else if constexpr(N == 16) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<2>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<3>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return bit_cast(tmp); +#endif + } + } +} + +template +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + // use fp32 store to mimic fp64 store + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { +#if 0 + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); +#else + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(bhalf_t), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + else if constexpr(N == 4) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 16) + { + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } +} + +template +__device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + +template +__device__ void amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff; + + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#else + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + + return src_thread_element_valid ? tmp : vector_t(0); +#endif +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size, + T customized_value) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + + return src_thread_element_valid ? tmp : vector_t(customized_value); +} + +// buffer_store requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// buffer_atomic_add requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/amd_inline_asm.hpp b/3rdparty/composable_kernel/include/ck/utility/amd_inline_asm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..82bf2a5eb576c838f6f3d3c54b4b74c3838896d4 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/amd_inline_asm.hpp @@ -0,0 +1,359 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_AMD_INLINE_ASM_HPP +#define CK_AMD_INLINE_ASM_HPP + +#include "data_type.hpp" +#include "c_style_pointer_cast.hpp" + +// TODO: deprecate all amd_assembly_outer_product_xxx + +namespace ck { + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) +{ + asm volatile("\n \ + v_fmac_f32 %0, %2, %3 \n \ + v_fmac_f32 %1, %2, %4 \n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4( + float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) +{ + asm volatile("\n \ + v_fmac_f32 %0, %4, %5 \n \ + v_fmac_f32 %1, %4, %6 \n \ + v_fmac_f32 %2, %4, %7 \n \ + v_fmac_f32 %3, %4, %8 \n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %3, %0\n \ + v_dot2_f32_f16 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); + + // do dot2 two times + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %4, %0\n \ + v_dot2_f32_f16 %1, %2, %6, %1\n \ + v_dot2_f32_f16 %0, %3, %5, %0\n \ + v_dot2_f32_f16 %1, %3, %7, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "0"(c0), + "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half2_t a, + half2_t b0, + half2_t b1, + half2_t b2, + half2_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %4, %5, %0\n \ + v_dot2_f32_f16 %1, %4, %6, %1\n \ + v_dot2_f32_f16 %2, %4, %7, %2\n \ + v_dot2_f32_f16 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half4_t a, + half4_t b0, + half4_t b1, + half4_t b2, + half4_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); + const half2_t* p_b2_half2 = c_style_pointer_cast(&b2); + const half2_t* p_b3_half2 = c_style_pointer_cast(&b3); + + // do dot2 two times + asm volatile("\n \ + v_dot2_f32_f16 %0, %4, %6, %0\n \ + v_dot2_f32_f16 %1, %4, %8, %1\n \ + v_dot2_f32_f16 %2, %4, %10, %2\n \ + v_dot2_f32_f16 %3, %4, %12, %3\n \ + v_dot2_f32_f16 %0, %5, %7, %0\n \ + v_dot2_f32_f16 %1, %5, %9, %1\n \ + v_dot2_f32_f16 %2, %5, %11, %2\n \ + v_dot2_f32_f16 %3, %5, %13, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "v"(p_b2_half2[0]), + "v"(p_b2_half2[1]), + "v"(p_b3_half2[0]), + "v"(p_b3_half2[1]), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +} + +__device__ void amd_assembly_outer_product_1x4(half8_t a, + half8_t b0, + half8_t b1, + half8_t b2, + half8_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + + // TODO remove pointer casting + const half4_t* p_a_half4 = c_style_pointer_cast(&a); + const half4_t* p_b0_half4 = c_style_pointer_cast(&b0); + const half4_t* p_b1_half4 = c_style_pointer_cast(&b1); + const half4_t* p_b2_half4 = c_style_pointer_cast(&b2); + const half4_t* p_b3_half4 = c_style_pointer_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3); +} + +__device__ void amd_assembly_outer_product_1x4(half16_t a, + half16_t b0, + half16_t b1, + half16_t b2, + half16_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half8_t* p_a_half8 = c_style_pointer_cast(&a); + const half8_t* p_b0_half8 = c_style_pointer_cast(&b0); + const half8_t* p_b1_half8 = c_style_pointer_cast(&b1); + const half8_t* p_b2_half8 = c_style_pointer_cast(&b2); + const half8_t* p_b3_half8 = c_style_pointer_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %2, %3, %0\n \ + v_dot4_i32_i8 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), + "0"(c0), + "1"(c1)); +#else + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); +#endif +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(int8x4_t a, + int8x4_t b0, + int8x4_t b1, + int8x4_t b2, + int8x4_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %4, %5, %0\n \ + v_dot4_i32_i8 %1, %4, %6, %1\n \ + v_dot4_i32_i8 %2, %4, %7, %2\n \ + v_dot4_i32_i8 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), + "v"(bit_cast(b2)), + "v"(bit_cast(b3)), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +#else + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); +#endif +} + +__device__ void amd_assembly_outer_product_1x4(int8x8_t a, + int8x8_t b0, + int8x8_t b1, + int8x8_t b2, + int8x8_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); +} + +__device__ void amd_assembly_outer_product_1x4(int8x16_t a, + int8x16_t b0, + int8x16_t b1, + int8x16_t b2, + int8x16_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I2], + vector_type{b0}.AsType()[I2], + vector_type{b1}.AsType()[I2], + vector_type{b2}.AsType()[I2], + vector_type{b3}.AsType()[I2], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I3], + vector_type{b0}.AsType()[I3], + vector_type{b1}.AsType()[I3], + vector_type{b2}.AsType()[I3], + vector_type{b3}.AsType()[I3], + c0, + c1, + c2, + c3); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/amd_llvm_intrinsic.hpp b/3rdparty/composable_kernel/include/ck/utility/amd_llvm_intrinsic.hpp new file mode 100644 index 0000000000000000000000000000000000000000..01e77d7be897085b8c5816fe9e20fe924b859ba7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/amd_llvm_intrinsic.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_AMD_LLVM_INTRINSIC_HPP +#define CK_AMD_LLVM_INTRINSIC_HPP + +#include "data_type.hpp" + +namespace ck { + +__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/amd_wmma.hpp b/3rdparty/composable_kernel/include/ck/utility/amd_wmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..752876a76920cb401b9334b3fac627e0c0895510 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/amd_wmma.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_AMD_WMMA_HPP +#define CK_AMD_WMMA_HPP + +#include "data_type.hpp" +// TODO: Add arch limitation +namespace ck { + +// wave32 only +// src: fp16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f16_w32; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w32; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: fp16, dst: fp16 +template +struct intrin_wmma_f16_16x16x16_f16_w32; + +template +struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: bf16, dst: bf16 +template +struct intrin_wmma_bf16_16x16x16_bf16_w32; + +template +struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w32; + +template +struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); + } +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/amd_xdlops.hpp b/3rdparty/composable_kernel/include/ck/utility/amd_xdlops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b4be0cbee70cbe0e6f6827b9c1f74fb9125e2a39 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/amd_xdlops.hpp @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_AMD_XDLOPS_HPP +#define CK_AMD_XDLOPS_HPP + +#include "data_type.hpp" + +namespace ck { + +// fp32 +template +struct intrin_mfma_f32_32x32x1f32; + +template <> +struct intrin_mfma_f32_32x32x1f32<64, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x1f32<32, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x2f32; + +template <> +struct intrin_mfma_f32_32x32x2f32<32, 32> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f32; + +template <> +struct intrin_mfma_f32_16x16x4f32<16, 16> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x1f32; + +template <> +struct intrin_mfma_f32_16x16x1f32<16, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32; + +template <> +struct intrin_mfma_f32_4x4x1f32<4, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + } +}; + +template <> +struct intrin_mfma_f32_4x4x1f32<8, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); + } +}; + +// fp16 +template +struct intrin_mfma_f32_32x32x4f16; + +template <> +struct intrin_mfma_f32_32x32x4f16<64, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x4f16<32, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x8f16; + +template <> +struct intrin_mfma_f32_32x32x8f16<32, 32> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16f16; + +template <> +struct intrin_mfma_f32_16x16x16f16<16, 16> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f16; + +template <> +struct intrin_mfma_f32_16x16x4f16<16, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16; + +template <> +struct intrin_mfma_f32_4x4x4f16<4, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + } +}; + +template <> +struct intrin_mfma_f32_4x4x4f16<8, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); + } +}; + +// bfp16 +template +struct intrin_mfma_f32_32x32x8bf16_1k; + +template <> +struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> +{ + template + __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16bf16_1k; + +template <> +struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> +{ + template + __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4bf16; + +template <> +struct intrin_mfma_f32_32x32x4bf16<32, 32> +{ + template + __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x8bf16; + +template <> +struct intrin_mfma_f32_16x16x8bf16<16, 16> +{ + template + __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_i32_32x32x8i8; + +template <> +struct intrin_mfma_i32_32x32x8i8<32, 32> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x16i8; + +template <> +struct intrin_mfma_i32_16x16x16i8<16, 16> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f64_16x16x4f64; + +template <> +struct intrin_mfma_f64_16x16x4f64<16, 16> +{ + template + __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) + { +#ifdef __gfx90a__ + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/array.hpp b/3rdparty/composable_kernel/include/ck/utility/array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..370a457fe9d97a621e6f91e78f14733398ebf756 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/array.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_ARRAY_HPP +#define CK_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +template +struct Array +{ + using type = Array; + using data_type = TData; + + TData mData[NSize]; + + __host__ __device__ static constexpr index_t Size() { return NSize; } + + __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; } + + __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; } + + __host__ __device__ constexpr const TData& operator[](index_t i) const { return At(i); } + + __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } +}; + +// empty Array +template +struct Array +{ + using type = Array; + using data_type = TData; + + __host__ __device__ static constexpr index_t Size() { return 0; } +}; + +template +__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) +{ + using data_type = remove_cvref_t; + return Array{std::forward(x), std::forward(xs)...}; +} + +// make empty array +template +__host__ __device__ constexpr auto make_array() +{ + return Array{}; +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/array_multi_index.hpp b/3rdparty/composable_kernel/include/ck/utility/array_multi_index.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9b8d5b95e9f6241b574104313b5b8fa2984ab18d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/array_multi_index.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_ARRAY_MULTI_INDEX_HPP +#define CK_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = Array; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +template +__host__ __device__ constexpr auto operator+=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; }); + return r; +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/c_style_pointer_cast.hpp b/3rdparty/composable_kernel/include/ck/utility/c_style_pointer_cast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e8b0081587afcb2db159a05ecd5fb940def68ff --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/c_style_pointer_cast.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_C_STYLE_POINTER_CAST_HPP +#define CK_C_STYLE_POINTER_CAST_HPP + +#include "type.hpp" +#include "enable_if.hpp" + +namespace ck { + +template && is_pointer_v, bool>::type = false> +__host__ __device__ PY c_style_pointer_cast(PX p_x) +{ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wcast-align" + return (PY)p_x; // NOLINT(old-style-cast, cast-align) +#pragma clang diagnostic pop +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/common_header.hpp b/3rdparty/composable_kernel/include/ck/utility/common_header.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1378bbe448e2e1862fed04596fba7f4b106d03b7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/common_header.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/array.hpp" +#include "ck/utility/container_helper.hpp" +#include "ck/utility/statically_indexed_array.hpp" +#include "ck/utility/container_element_picker.hpp" +#include "ck/utility/multi_index.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/functional3.hpp" +#include "ck/utility/functional4.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/ignore.hpp" +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/number.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/magic_division.hpp" +#include "ck/utility/c_style_pointer_cast.hpp" +#include "ck/utility/is_known_at_compile_time.hpp" +#include "ck/utility/transpose_vectors.hpp" +#include "ck/utility/inner_product.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/debug.hpp" + +#include "ck/utility/amd_buffer_addressing.hpp" +#include "ck/utility/generic_memory_space_atomic.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/synchronization.hpp" +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/static_buffer.hpp" +#include "ck/utility/dynamic_buffer.hpp" + +// TODO: remove this +#if CK_USE_AMD_INLINE_ASM +#include "ck/utility/amd_inline_asm.hpp" +#endif + +#ifdef CK_USE_AMD_MFMA +#include "ck/utility/amd_xdlops.hpp" +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/container_element_picker.hpp b/3rdparty/composable_kernel/include/ck/utility/container_element_picker.hpp new file mode 100644 index 0000000000000000000000000000000000000000..abc5185e04a5079edb95fe15bfcd6788256d384a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/container_element_picker.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP +#define CK_CONTAINER_ELEMENT_PICKER_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ContainerElementPicker +{ + using type = ContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ContainerElementPicker() = delete; + + __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr auto& At(Number i) + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray(IP); + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + private: + Arr& mArray; +}; + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ConstantContainerElementPicker +{ + using type = ConstantContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ConstantContainerElementPicker() = delete; + + __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + private: + const Arr& mArray; +}; + +template +__host__ __device__ constexpr auto operator+=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) += x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto operator-=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) -= x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto pick_container_element(Arr& a, Picks) +{ + return ContainerElementPicker(a); +} + +template +__host__ __device__ constexpr auto pick_container_element(const Arr& a, Picks) +{ + return ConstantContainerElementPicker(a); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/container_helper.hpp b/3rdparty/composable_kernel/include/ck/utility/container_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c8b02bc5acaf8542bf028fff39c42061db3b3296 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/container_helper.hpp @@ -0,0 +1,393 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_CONTAINER_HELPER_HPP +#define CK_CONTAINER_HELPER_HPP + +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "array.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto container_push_back(const Array& a, const TData& x) +{ + Array r; + + static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + + r(Number{}) = x; + + return r; +} + +template +__host__ __device__ constexpr auto container_push_front(const Tuple& a, const T& x) +{ + return container_concat(make_tuple(x), a); +} + +template +__host__ __device__ constexpr auto container_push_back(const Tuple& a, const T& x) +{ + return container_concat(a, make_tuple(x)); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_new2old(const Array& old_array, Sequence /*new2old*/) +{ + static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_array(old_array[Number{}]...); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_old2new(const Array& old_array, Sequence old2new) +{ + return container_reorder_given_new2old( + old_array, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple& old_tuple, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_tuple(old_tuple[Number{}]...); +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple& old_tuple, + Sequence old2new) +{ + return container_reorder_given_new2old( + old_tuple, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence /* old_seq */, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return Sequence::At(Number{})...>{}; +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence old_seq, + Sequence /* old2new */) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + constexpr auto new2old = typename sequence_map_inverse>::type{}; + + return container_reorder_given_new2old(old_seq, new2old); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm-4.1 compiler would crash for recursive lambda +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto r_old) { + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + // recursively call f/fs + return fs(fs, i + Number{}, r_new); + } + else + { + return r_new; + } + }; + + // start recursion + return f(f, Number{}, init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reduce_impl( + const Container& x, Reduce reduce, ROld r_old, Number i, Number, Number) +{ + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + return container_reduce_impl( + x, reduce, r_new, i + Number{}, Number{}, Number{}); + } + else + { + return r_new; + } +} + +// rocm-4.1 compiler would crash for recursive lambda +// container reduce with initial value +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + if constexpr(IEnd > IBegin) + { + return container_reduce_impl( + x, reduce, init, Number{}, Number{}, Number{}); + } + else + { + return init; + } +} +#endif + +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + y(i) = r; + r = f(r, x[i]); + }); + + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Sequence& seq, Reduce f, Number) +{ + return reverse_exclusive_scan_sequence(seq, f, Number{}); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm4.1 compiler would crash with recursive lambda +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto y_old, auto r_old) { + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return fs(fs, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } + }; + + // start recursion + return f(f, Number{}, make_tuple(init), init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl( + const Tuple& x, Reduce reduce, Number i, YOld y_old, ROld r_old) +{ + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + return container_reverse_exclusive_scan_impl( + x, reduce, Number{}, make_tuple(init), init); +} +#endif + +// TODO: update to like container_reverse_exclusive_scan to deal with Tuple of Numebr<> +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Tuple& x, Reduce f, TData init) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys) +{ + return container_concat(x, container_concat(ys...)); +} + +template +__host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) +{ + return unpack2( + [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); +} + +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); +} + +template +__host__ __device__ constexpr auto container_concat(const Container& x) +{ + return x; +} + +template +__host__ __device__ constexpr auto get_container_subset(const Array& arr, Sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + return make_array(arr[Number{}]...); +} + +template +__host__ __device__ constexpr auto get_container_subset(const Tuple& tup, Sequence) +{ + static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size"); + + return make_tuple(tup[Number{}]...); +} + +template +__host__ __device__ constexpr void +set_container_subset(Array& y, Sequence picks, const Array& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr void +set_container_subset(Tuple& y, Sequence picks, const Tuple& x) +{ + static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) +{ + using Seq = Sequence; + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Seq::At(i); + return Number{}; + }, + Seq::Size()); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/data_type.hpp b/3rdparty/composable_kernel/include/ck/utility/data_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..40ee8b617e272ac2fa436cadcd6aad76a93720e7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/data_type.hpp @@ -0,0 +1,1059 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/statically_indexed_array.hpp" + +namespace ck { + +using bhalf_t = ushort; +using half_t = _Float16; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using int4_t = _BitInt(4); +#endif + +// vector_type +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + +// scalar_type +template +struct scalar_type; + +// is_scalar_type +template +struct is_scalar_type +{ + static constexpr bool value = (scalar_type>::vector_size == 1); +}; + +// has_same_scalar_type +template +using has_same_scalar_type = is_same>::type, + typename scalar_type>::type>; + +template +struct scalar_type +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +// +template <> +struct scalar_type +{ + using type = double; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = float; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = half_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bhalf_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int32_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int8_t; + static constexpr index_t vector_size = 1; +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct scalar_type +{ + using type = int4_t; + static constexpr index_t vector_size = 1; +}; +#endif + +// +template +struct vector_type +{ + using d1_t = T; + using type = d1_t; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } +}; + +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; +using bhalf64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// Convert X to Y +template +__host__ __device__ constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} + +// convert bfp16 to fp32 +template <> +inline __host__ __device__ constexpr float type_convert(bhalf_t x) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + + return u.fp32; +} + +// convert fp32 to bfp16 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.int32 |= 0x10000; // Preserve signaling NaN + } + + return uint16_t(u.int32 >> 16); +} + +template +struct NumericLimits +{ + __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } + + __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } + + __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } + + __host__ __device__ static constexpr T QuietNaN() + { + return std::numeric_limits::quiet_NaN(); + } + + __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } + + __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } + + __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } +}; +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/debug.hpp b/3rdparty/composable_kernel/include/ck/utility/debug.hpp new file mode 100644 index 0000000000000000000000000000000000000000..593bbb711672f7dc6e3db22d877bf249b5dff56d --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/debug.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef UTILITY_DEBUG_HPP +#define UTILITY_DEBUG_HPP + +namespace ck { +namespace debug { + +namespace detail { +template +struct PrintAsType; + +template +struct PrintAsType::value>::type> +{ + using type = float; + __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast(p)); } +}; + +template <> +struct PrintAsType +{ + using type = float; + __host__ __device__ static void Print(const ck::half_t& p) + { + printf("%.3f ", static_cast(p)); + } +}; + +template +struct PrintAsType::value>::type> +{ + using type = int; + __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast(p)); } +}; +} // namespace detail + +// Print at runtime the data in shared memory in 128 bytes per row format given shared mem pointer +// and the number of elements. Can optionally specify strides between elements and how many bytes' +// worth of data per row. +// +// Usage example: +// +// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize())); +// +template +__device__ void print_shared(T const* p_shared, index_t num_elements) +{ + constexpr index_t row_elements = row_bytes / sizeof(T); + static_assert((element_stride >= 1 && element_stride <= row_elements), + "element_stride should between [1, row_elements]"); + + index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + index_t tid = + (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x; + + __syncthreads(); + + if(tid == 0) + { + printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", + wgid, + row_bytes, + element_stride); + for(index_t i = 0; i < num_elements; i += row_elements) + { + printf("elem %5d: ", i); + for(index_t j = 0; j < row_elements; j += element_stride) + { + detail::PrintAsType::Print(p_shared[i + j]); + } + + printf("\n"); + } + printf("\n"); + } + + __syncthreads(); +} + +} // namespace debug +} // namespace ck + +#endif // UTILITY_DEBUG_HPP diff --git a/3rdparty/composable_kernel/include/ck/utility/dynamic_buffer.hpp b/3rdparty/composable_kernel/include/ck/utility/dynamic_buffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c6f0d299ef3c35c63ad97424888545edd2874914 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/dynamic_buffer.hpp @@ -0,0 +1,398 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" +#include "enable_if.hpp" +#include "c_style_pointer_cast.hpp" +#include "amd_buffer_addressing.hpp" +#include "generic_memory_space_atomic.hpp" + +namespace ck { + +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +template +struct DynamicBuffer +{ + using type = T; + + T* p_data_; + ElementSpaceSize element_space_size_; + T invalid_element_value_ = T{0}; + + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) + : p_data_{p_data}, element_space_size_{element_space_size} + { + } + + __host__ __device__ constexpr DynamicBuffer(T* p_data, + ElementSpaceSize element_space_size, + T invalid_element_value) + : p_data_{p_data}, + element_space_size_{element_space_size}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() + { + return BufferAddressSpace; + } + + __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_zero, t_per_x>( + p_data_, i, is_valid_element, element_space_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::Set) + { + this->template Set(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd) + { + this->template AtomicAdd(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax) + { + this->template AtomicMax(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template Get(i, is_valid_element); + this->template Set(i, is_valid_element, x + tmp); + // tmp += x; + // this->template Set(i, is_valid_element, tmp); + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_STORE + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + +#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + bool constexpr workaround_int8_ds_write_issue = true; +#else + bool constexpr workaround_int8_ds_write_issue = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && + is_same>::type, int8_t>::value && + workaround_int8_ds_write_issue) + { + if(is_valid_element) + { + // HACK: compiler would lower IR "store address_space(3)" into inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x16_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) + { + using scalar_t = typename scalar_type>::type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, int32_t> || + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) + bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; +#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else + { + if(is_valid_element) + { + atomic_add(c_style_pointer_cast(&p_data_[i]), x); + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = is_same_v, double>; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_max, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if(is_valid_element) + { + atomic_max(c_style_pointer_cast(&p_data_[i]), x); + } + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) +{ + return DynamicBuffer{p, element_space_size}; +} + +template < + AddressSpaceEnum BufferAddressSpace, + typename T, + typename ElementSpaceSize, + typename X, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) +{ + return DynamicBuffer{ + p, element_space_size, invalid_element_value}; +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/enable_if.hpp b/3rdparty/composable_kernel/include/ck/utility/enable_if.hpp new file mode 100644 index 0000000000000000000000000000000000000000..297434b0dddd5f1a680176c3bd099bedfda8dff4 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/enable_if.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +template +using enable_if = std::enable_if; + +template +using enable_if_t = typename std::enable_if::type; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/functional.hpp b/3rdparty/composable_kernel/include/ck/utility/functional.hpp new file mode 100644 index 0000000000000000000000000000000000000000..08e730782f386cf5788c64bc04a008dc6cb37b28 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/functional.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/type.hpp" + +namespace ck { + +// TODO: right? wrong? +struct forwarder +{ + template + __host__ __device__ constexpr T&& operator()(T&& x) const + { + return static_cast(x); + } +}; + +struct swallow +{ + template + __host__ __device__ constexpr swallow(Ts&&...) + { + } +}; + +template +struct logical_and +{ + constexpr bool operator()(const T& x, const T& y) const { return x && y; } +}; + +template +struct logical_or +{ + constexpr bool operator()(const T& x, const T& y) const { return x || y; } +}; + +template +struct logical_not +{ + constexpr bool operator()(const T& x) const { return !x; } +}; + +// Emulate if constexpr +template +struct static_if; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F f) const + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + return Type{}; + } + + template + __host__ __device__ static void Else(F) + { + } +}; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F) const + { + return Type{}; + } + + template + __host__ __device__ static void Else(F f) + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + } +}; + +template +struct conditional; + +template +struct conditional +{ + using type = X; +}; + +template +struct conditional +{ + using type = Y; +}; + +template +using conditional_t = typename conditional::type; + +// z = predicate ? x : y +template +constexpr auto conditional_expr(X&& x, Y&& y) +{ + if constexpr(predicate) + { + return std::forward(x); + } + else + { + return std::forward(y); + } +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/functional2.hpp b/3rdparty/composable_kernel/include/ck/utility/functional2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f125ca4c944777d02ae1083334bec4602ad68c4 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/functional2.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { + +namespace detail { + +template +struct static_for_impl; + +template +struct static_for_impl> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + swallow{(f(Number{}), 0)...}; + } +}; + +} // namespace detail + +// F signature: F(Number) +template +struct static_for +{ + __host__ __device__ constexpr static_for() + { + static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, + "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), + "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && " + "NBegin >= NEnd)"); + } + + template + __host__ __device__ constexpr void operator()(F f) const + { + detail::static_for_impl::type>{}( + f); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/functional3.hpp b/3rdparty/composable_kernel/include/ck/utility/functional3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..06b67ef7e3fdec25a1ab2e18e23c904342233ebe --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/functional3.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/multi_index.hpp" + +namespace ck { + +namespace detail { + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct static_ford_impl +{ + __host__ __device__ constexpr static_ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); + } + + // F signature: F(Sequence<...>) + // CurrentOrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const + { + static_for<0, RemainLengths::Front(), 1>{}([=](auto I) { + static_ford_impl{}( + f, CurrentOrderedId::PushBack(I)); + }); + } +}; + +template +struct static_ford_impl, Orders> +{ + // F signature: F(Sequence<...>) + // OrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, OrderedId) const + { + // retrive unordered Id + f(OrderedId::ReorderGivenOld2New(Orders{})); + } +}; + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct ford_impl +{ + __host__ __device__ constexpr ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); + } + + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + for(index_t i = 0; i < RemainLengths::Front(); ++i) + { + ford_impl{}( + f, container_push_back(current_ordered_id, i)); + } + } +}; + +template +struct ford_impl, Orders> +{ + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + // retrive unordered Id + f(container_reorder_given_old2new(current_ordered_id, Orders{})); + } +}; + +} // namespace detail + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which static_ford +// will loop over each +// dimension +template ::type> +struct static_ford +{ + __host__ __device__ constexpr static_ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size"); + } + + // F signature: F(Sequence<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + detail::static_ford_impl{}(f, Sequence<>{}); + } +}; + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which ford will loop +// over each +// dimension +template ::type> +struct ford +{ + __host__ __device__ constexpr ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size"); + } + + // F signature: F(Array<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + + for(index_t i = 0; i < ordered_lengths.Front(); ++i) + { + detail::ford_impl{}(f, + make_multi_index(i)); + } + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/functional4.hpp b/3rdparty/composable_kernel/include/ck/utility/functional4.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6eeaf15c9b7ac283f7fcac96c195d28f179b6065 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/functional4.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_FUNCTIONAL4_HPP +#define CK_FUNCTIONAL4_HPP + +#include "sequence.hpp" +#include "tuple.hpp" +#include "array.hpp" + +namespace ck { + +namespace detail { + +template +struct unpack_impl; + +template +struct unpack_impl> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x) const + { + return std::forward(f)(std::forward(x).At(Number{})...); + } +}; + +template +struct unpack2_impl; + +// TODO: remove this, after properly implementing unpack that takes any number of containers +template +struct unpack2_impl, Sequence> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const + { + return std::forward(f)(std::forward(x).At(Number{})..., + std::forward(y).At(Number{})...); + } +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto unpack(F&& f, X&& x) +{ + using X_ = remove_reference_t; + return detail::unpack_impl::type>{}( + std::forward(f), std::forward(x)); +} + +// TODO: properly implement unpack that takes any number of containers +template +__host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) +{ + using X_ = remove_reference_t; + using Y_ = remove_reference_t; + return detail::unpack2_impl::type, + typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( + std::forward(f), std::forward(x), std::forward(y)); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/generic_memory_space_atomic.hpp b/3rdparty/composable_kernel/include/ck/utility/generic_memory_space_atomic.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a1ca966521710ecdb4ceaad9a83457031f4c508 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/generic_memory_space_atomic.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "data_type.hpp" + +namespace ck { + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. +template +__device__ X atomic_add(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_add(int32_t* p_dst, const int32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ uint32_t atomic_add(uint32_t* p_dst, const uint32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float atomic_add(float* p_dst, const float& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ double atomic_add(double* p_dst, const double& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float2_t atomic_add(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_max explicit for +// each datatype. + +template +__device__ X atomic_max(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_max(int32_t* p_dst, const int32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ uint32_t atomic_max(uint32_t* p_dst, const uint32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float atomic_max(float* p_dst, const float& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ double atomic_max(double* p_dst, const double& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float2_t atomic_max(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicMax(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicMax(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/get_id.hpp b/3rdparty/composable_kernel/include/ck/utility/get_id.hpp new file mode 100644 index 0000000000000000000000000000000000000000..44ff438155d2a9c9ba9b0925d4e9fe95c1fa6bce --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/get_id.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { + +__host__ __device__ constexpr index_t get_warp_size() +{ + // warpSize is defined by HIP + return warpSize; +} + +__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } + +__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + +__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } + +__device__ index_t get_block_1d_id() { return blockIdx.x; } + +__device__ index_t get_grid_size() { return gridDim.x; } + +__device__ index_t get_block_size() { return blockDim.x; } + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/ignore.hpp b/3rdparty/composable_kernel/include/ck/utility/ignore.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac33cbf9a508f5e6402d17d6285ae34a8196f6eb --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/ignore.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// https://en.cppreference.com/w/cpp/utility/tuple/ignore + +namespace ck { + +namespace detail { +struct ignore_t +{ + template + constexpr void operator=(T&&) const noexcept + { + } +}; +} // namespace detail + +inline constexpr detail::ignore_t ignore; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/inner_product.hpp b/3rdparty/composable_kernel/include/ck/utility/inner_product.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5da1ec9bcc55ab959c0b5400a274390fc8fc66c4 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/inner_product.hpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "data_type.hpp" + +namespace ck { + +template +__device__ void inner_product(const TA& a, const TB& b, TC& c); + +template <> +__device__ void inner_product(const float& a, const float& b, float& c) +{ +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) + asm volatile("\n \ + v_mac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) + asm volatile("\n \ + v_fmac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c += a * b; +#endif +} + +template <> +__device__ void +inner_product(const float2_t& a, const float2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const float4_t& a, const float4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void inner_product(const half_t& a, const half_t& b, float& c) +{ + c+=static_cast(a*b); +} + +template <> +__device__ void inner_product(const half2_t& a, const half2_t& b, float& c) +{ +#if defined(CK_USE_AMD_V_DOT2_F32_F16) +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM + asm volatile("\n \ + v_dot2_f32_f16 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c = __builtin_amdgcn_sdot2(a, b, c, false); +#endif +#else + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 2, 1>{}([&](auto i) { + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void inner_product(const half4_t& a, const half4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void inner_product(const half8_t& a, const half8_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void +inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) +{ +#if defined(CK_USE_AMD_V_DOT4_I32_I8) +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM + asm volatile("\n \ + v_dot4_i32_i8 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); +#else + c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); +#endif +#else + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 4, 1>{}([&](auto i) { + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void +inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/integral_constant.hpp b/3rdparty/composable_kernel/include/ck/utility/integral_constant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9aab4e24214a884bdfb391e0f9040bb3af2630ec --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/integral_constant.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +template +struct integral_constant +{ + static constexpr T value = v; + typedef T value_type; + typedef integral_constant type; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +template +__host__ __device__ constexpr auto operator+(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator-(integral_constant, integral_constant) +{ + static_assert(Y <= X, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator*(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator/(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator%(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/is_known_at_compile_time.hpp b/3rdparty/composable_kernel/include/ck/utility/is_known_at_compile_time.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8198154422e5d5228b79c6539bddcd5070d1e25c --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/is_known_at_compile_time.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "integral_constant.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +template +struct is_known_at_compile_time; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template +struct is_known_at_compile_time> +{ + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return container_reduce( + Tuple{}, + [](auto x, bool r) { + return is_known_at_compile_time>::value & r; + }, + true); + } + + static constexpr bool value = IsKnownAtCompileTime(); +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/magic_division.hpp b/3rdparty/composable_kernel/include/ck/utility/magic_division.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a5e8e9216519074d6542d379512c5b8d13ec21c9 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/magic_division.hpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "integral_constant.hpp" +#include "number.hpp" +#include "type.hpp" +#include "tuple.hpp" + +namespace ck { + +// magic number division +// Caution: +// 1. For uint32_t as dividend: magic number division implementation being used would produce +// correct result if the dividend is uint32_t and its value is within 31-bit value range. +// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been +// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number +// division implementation for uint32_t is then used. Therefore, dividend value need to be +// non-negative. +// TODO: +// 1. Implement magic number divison for int32_t +// 2. Implement magic number divison for unit32_t with 32-bit value range +struct MagicDivision +{ + // uint32_t + __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) + { + // WARNING: magic division is only applicable for division inside this range. + // You should use the return value of CalculateMagicNumbers, if division is not inside this + // range. The "else" logic below is to quiet down run-time error. + if(divisor >= 1 && divisor <= INT32_MAX) + { + uint32_t shift = 0; + for(shift = 0; shift < 32; ++shift) + { + if((1U << shift) >= divisor) + { + break; + } + } + + uint64_t one = 1; + uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + // assert(multiplier <= 0xffffffffUL); + + return make_tuple(uint32_t(multiplier), shift); + } + else + { + return make_tuple(uint32_t(0), uint32_t(0)); + } + } + + __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor) + { + auto tmp = CalculateMagicNumbers(divisor); + + return tmp[Number<0>{}]; + } + + __host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor) + { + auto tmp = CalculateMagicNumbers(divisor); + + return tmp[Number<1>{}]; + } + + // integral_constant + template + __host__ __device__ static constexpr auto + CalculateMagicNumbers(integral_constant) + { + constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); + + constexpr uint32_t multiplier = tmp[Number<0>{}]; + constexpr uint32_t shift = tmp[Number<1>{}]; + + return make_tuple(integral_constant{}, + integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicMultiplier(integral_constant) + { + constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); + + return integral_constant{}; + } + + template + __host__ __device__ static constexpr auto + CalculateMagicShift(integral_constant) + { + constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); + + return integral_constant{}; + } + + // integral_constant + template + __host__ __device__ static constexpr auto + CalculateMagicNumbers(integral_constant) + { + return CalculateMagicNumbers(integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicMultiplier(integral_constant) + { + return CalculateMagicMultiplier(integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicShift(integral_constant) + { + return CalculateMagicShift(integral_constant{}); + } + + // magic division for uint32_t + __device__ static constexpr uint32_t + DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = __umulhi(dividend, multiplier); + return (tmp + dividend) >> shift; + } + + __host__ static constexpr uint32_t + DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = static_cast(dividend) * multiplier >> 32; + return (tmp + dividend) >> shift; + } + + // magic division for int32_t + // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be + // non-negative for result to be correct + // TODO: figure out how to do magic number divison for int32_t as dividended + __device__ static constexpr int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = __umulhi(dividend_u32, multiplier); + return (tmp + dividend_u32) >> shift; + } + + __host__ static constexpr int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = static_cast(dividend_u32) * multiplier >> 32; + return (tmp + dividend_u32) >> shift; + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/math.hpp b/3rdparty/composable_kernel/include/ck/utility/math.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12203bd7f3146955b2568dcfb40c26403ff4b70a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/math.hpp @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "integral_constant.hpp" +#include "number.hpp" +#include "type.hpp" +#include "enable_if.hpp" + +namespace ck { +namespace math { + +template +struct scales +{ + __host__ __device__ constexpr T operator()(T a) const { return s * a; } +}; + +template +struct plus +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } +}; + +template +struct minus +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; } +}; + +struct multiplies +{ + template + __host__ __device__ constexpr auto operator()(const A& a, const B& b) const + { + return a * b; + } +}; + +template +struct maximize +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; } +}; + +template +struct minimize +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; } +}; + +template +struct integer_divide_ceiler +{ + __host__ __device__ constexpr T operator()(T a, T b) const + { + static_assert(is_same{} || is_same{}, "wrong type"); + + return (a + b - Number<1>{}) / b; + } +}; + +template +__host__ __device__ constexpr auto integer_divide_floor(X x, Y y) +{ + return x / y; +} + +template +__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) +{ + return (x + y - Number<1>{}) / y; +} + +template +__host__ __device__ constexpr auto integer_least_multiple(X x, Y y) +{ + return y * integer_divide_ceil(x, y); +} + +template +__host__ __device__ constexpr T max(T x) +{ + return x; +} + +template +__host__ __device__ constexpr T max(T x, T y) +{ + return x > y ? x : y; +} + +template +__host__ __device__ constexpr index_t max(Number, index_t y) +{ + return X > y ? X : y; +} + +template +__host__ __device__ constexpr index_t max(index_t x, Number) +{ + return x > Y ? x : Y; +} + +template +__host__ __device__ constexpr auto max(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return max(x, max(ys...)); +} + +template +__host__ __device__ constexpr T min(T x) +{ + return x; +} + +template +__host__ __device__ constexpr T min(T x, T y) +{ + return x < y ? x : y; +} + +template +__host__ __device__ constexpr index_t min(Number, index_t y) +{ + return X < y ? X : y; +} + +template +__host__ __device__ constexpr index_t min(index_t x, Number) +{ + return x < Y ? x : Y; +} + +template +__host__ __device__ constexpr auto min(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return min(x, min(ys...)); +} + +template +__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound) +{ + return min(max(x, lowerbound), upperbound); +} + +// disallow implicit type casting +template +__device__ T exp(T x); + +// TODO: add f16 support using v_exp_f16 + +template <> +__device__ float exp(float x) +{ + return __expf(x); +} + +template <> +__device__ double exp(double x) +{ + return exp(x); +} + +// greatest common divisor, aka highest common factor +__host__ __device__ constexpr index_t gcd(index_t x, index_t y) +{ + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } + else + { + return gcd(x, y % x); + } +} + +template +__host__ __device__ constexpr auto gcd(Number, Number) +{ + constexpr auto r = gcd(X, Y); + + return Number{}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto gcd(X x, Ys... ys) +{ + return gcd(x, gcd(ys...)); +} + +// least common multiple +template +__host__ __device__ constexpr auto lcm(X x, Y y) +{ + return (x * y) / gcd(x, y); +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto lcm(X x, Ys... ys) +{ + return lcm(x, lcm(ys...)); +} + +template +struct equal +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; } +}; + +template +struct less +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } +}; + +} // namespace math +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/math_v2.hpp b/3rdparty/composable_kernel/include/ck/utility/math_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dc97666b375255849a4da2dd2efb911026b2d4ed --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/math_v2.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/type.hpp" + +namespace ck { +namespace math { + +// math functions for the host, some are implemented by calling C++ std functions + +static inline __host__ float abs(float x) { return std::abs(x); }; + +static inline __host__ double abs(double x) { return std::abs(x); }; + +static inline __host__ int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __host__ int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __host__ half_t abs(half_t x) +{ + uint16_t xx = ck::bit_cast(x); + + uint16_t abs_xx = xx & 0x7fff; + + half_t abs_x = ck::bit_cast(abs_xx); + + return abs_x; +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +static inline __host__ int4_t abs(int4_t x) +{ + int4_t sgn = x >> (4 - 1); + return (x ^ sgn) - sgn; +} +#endif + +static inline __host__ bool isnan(float x) { return std::isnan(x); }; + +static inline __host__ bool isnan(double x) { return std::isnan(x); }; + +static inline __host__ bool isnan(int8_t x) +{ + (void)x; + return false; +}; + +static inline __host__ bool isnan(int32_t x) +{ + (void)x; + return false; +}; + +static inline __host__ bool isnan(half_t x) +{ + uint16_t xx = ck::bit_cast(x); + + return (xx & 0x7FFF) > 0x7C00; +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +static inline __host__ bool isnan(int4_t x) +{ + (void)x; + return false; +}; +#endif + +static inline __host__ float sqrt(float x) { return std::sqrt(x); }; + +static inline __host__ double sqrt(double x) { return std::sqrt(x); }; + +// math functions for the HIP kernel, some are implemented by calling hip builtin functions + +static inline __device__ float abs(float x) { return ::abs(x); }; + +static inline __device__ double abs(double x) { return ::abs(x); }; + +static inline __device__ int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __device__ int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +static inline __device__ int4_t abs(int4_t x) +{ + int4_t sgn = x >> (4 - 1); + + return (x ^ sgn) - sgn; +}; +#endif + +static inline __device__ half_t abs(half_t x) +{ + uint16_t xx = ck::bit_cast(x); + + uint16_t abs_xx = xx & 0x7fff; + + half_t abs_x = ck::bit_cast(abs_xx); + + return abs_x; +}; + +static inline __device__ bool isnan(float x) { return ::isnan(x); }; + +static inline __device__ bool isnan(double x) { return ::isnan(x); }; + +static inline __device__ bool isnan(int8_t x) +{ + (void)x; + return false; +}; + +static inline __device__ bool isnan(int32_t x) +{ + (void)x; + return false; +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +static inline __device__ bool isnan(int4_t x) +{ + (void)x; + return false; +}; +#endif + +static inline __device__ bool isnan(half_t x) +{ + uint16_t xx = ck::bit_cast(x); + + return (xx & 0x7FFF) > 0x7C00; +}; + +static inline __device__ float sqrt(float x) { return ::sqrtf(x); }; + +static inline __device__ double sqrt(double x) { return ::sqrt(x); }; + +} // namespace math +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/multi_index.hpp b/3rdparty/composable_kernel/include/ck/utility/multi_index.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1d544c0906cae3c61f4d7ce27e74e7636c3919b5 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/multi_index.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "common_header.hpp" + +#if CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX +#include "array_multi_index.hpp" +#else +#include "statically_indexed_array_multi_index.hpp" +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/number.hpp b/3rdparty/composable_kernel/include/ck/utility/number.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f3ca6b61dc6ac08330a6cf1633148bb9bad8cd81 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/number.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_NUMBER_HPP +#define CK_NUMBER_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +using Number = integral_constant; + +template +using LongNumber = integral_constant; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/print.hpp b/3rdparty/composable_kernel/include/ck/utility/print.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eed1ca42c7346e32f4d6924452d568bf431d6c24 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/print.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_PRINT_HPP +#define CK_PRINT_HPP + +#include "array.hpp" +#include "statically_indexed_array.hpp" +#include "container_helper.hpp" +#include "sequence.hpp" + +namespace ck { + +template +__host__ __device__ void print_array(const char* s, T a) +{ + constexpr index_t nsize = a.Size(); + + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); + printf("}\n"); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/reduction_common.hpp b/3rdparty/composable_kernel/include/ck/utility/reduction_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aceef7b296da7466c484aad8033fe1f949662fca --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/reduction_common.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_enums.hpp" + +namespace ck { + +struct float_equal_one +{ + template + __host__ __device__ inline bool operator()(T x) + { + return x <= static_cast(1.0f) and x >= static_cast(1.0f); + }; +}; + +struct float_equal_zero +{ + template + __host__ __device__ inline bool operator()(T x) + { + return x <= static_cast(0.0f) and x >= static_cast(0.0f); + }; +}; + +template +static constexpr __device__ index_t get_shift() +{ + return (get_shift() + 1); +}; + +template <> +constexpr __device__ index_t get_shift<1>() +{ + return (0); +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/reduction_enums.hpp b/3rdparty/composable_kernel/include/ck/utility/reduction_enums.hpp new file mode 100644 index 0000000000000000000000000000000000000000..67856331059cef628a62e4b2223f262b839bbf02 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/reduction_enums.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +enum struct ReduceTensorOp +{ + ADD = 0, + MUL = 1, + MIN = 2, + MAX = 3, + AMAX = 4, + AVG = 5, + NORM1 = 6, + NORM2 = 7, + // MUL_NO_ZEROS = 8, +}; + +enum struct NanPropagation +{ + NOT_PROPAGATE_NAN = 0, + PROPAGATE_NAN = 1, +}; + +enum struct ReduceTensorIndices +{ + NO_INDICES = 0, + FLATTENED_INDICES = 1, +}; + +enum struct IndicesType +{ + INDICES_32BIT = 0, + INDICES_64BIT = 1, + INDICES_16BIT = 2, + INDICES_8BIT = 3, +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/reduction_functions_accumulate.hpp b/3rdparty/composable_kernel/include/ck/utility/reduction_functions_accumulate.hpp new file mode 100644 index 0000000000000000000000000000000000000000..724e5599d6c878ba08bf14429c40c0958ab9f2a2 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/reduction_functions_accumulate.hpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" + +namespace ck { +namespace detail { + +// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs) +template +struct AccumulateWithNanIgnore +{ + __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + if(!ck::math::isnan(currVal)) + { + ReduceOperation{}(accuVal, currVal); + } + }; +}; + +template +struct AccumulateWithNanCheck; + +// Does not check for NaN; does not guarantee NaNs be propagated to result +// e.g., given that max(a, b) = a > b ? a : b +// then max(NaN, 1) returns 1 +// max(1, NaN) returns NaN +// since any comparison involving NaNs returns false +template +struct AccumulateWithNanCheck +{ + // cppcheck-suppress constParameter + __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + ReduceOperation{}(accuVal, currVal); + }; +}; + +// Check for NaN; guarantees NaNs be propagated to result +template +struct AccumulateWithNanCheck +{ + __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + using ck::math::isnan; + + if(isnan(currVal)) + { + accuVal = currVal; + } + else + { + ReduceOperation{}(accuVal, currVal); + }; + }; +}; + +template +struct AccumulateWithIndexAndNanCheck; + +template +struct AccumulateWithIndexAndNanCheck +{ + __host__ __device__ static inline void + // cppcheck-suppress constParameter + Calculate(AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) + { + bool changed = false; + + ReduceOperation{}(accuVal, currVal, changed); + + if(changed) + accuIndex = currIndex; + }; +}; + +template +struct AccumulateWithIndexAndNanCheck +{ + // The method is called when the ReduceOperation is indexable and the user asked for indices + __host__ __device__ static inline void Calculate(AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) + { + using ck::math::isnan; + + if(isnan(currVal)) + { + accuVal = currVal; + accuIndex = currIndex; + } + else + { + bool changed = false; + + ReduceOperation{}(accuVal, currVal, changed); + + if(changed) + accuIndex = currIndex; + } + }; +}; + +} // namespace detail +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/reduction_operator.hpp b/3rdparty/composable_kernel/include/ck/utility/reduction_operator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..25ae8fd34faf5452a29fb7c8271de01bae9badda --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/reduction_operator.hpp @@ -0,0 +1,292 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type.hpp" + +namespace ck { + +namespace reduce { + +// Every binary operator used in reduction is represented by a templated functor class. Each functor +// class must provide at least +// three members: +// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary +// operator, "identity element" is the unique +// element in the algebraic space that doesn't affect the value of other elements +// when operated against them, and the concept is similar to zero vector in +// vector space +// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). +// 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this +// operator can use the InMemoryDataOperation to finalize, or else it return false +// 3) operator() -- the first argument of the operator must be both an input & output, and the +// corresponding variable usually stores +// the accumulated result of many operator() calls; the second argument is only an +// input. For indexable binary +// operator, the second version of operator() has third argument (which is an +// output) to indicate whether the +// accumulated value (the first argument) has changed, in which case the recorded +// accumulated index also need be +// changed. + +struct Add +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::AtomicAdd || + operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Add accumulator!"); + + a = a + b; + } +}; + +struct SquaredAdd +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::AtomicAdd || + operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the SquaredAdd accumulator!"); + + a = a + b * b; + } +}; + +struct Mul +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(1.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Mul accumulator!"); + + a = a * b; + } +}; + +struct Max +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return NumericLimits::Lowest(); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + + if(a < b) + a = b; + } + + template + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + + if(a < b) + { + a = b; + changed = true; + } + } +}; + +struct Min +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return NumericLimits::Max(); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_min to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + + if(a > b) + a = b; + } + + template + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + + if(a > b) + { + a = b; + changed = true; + } + } +}; + +struct AMax +{ + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + + if(a < b) + a = b; + } + + template + __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + + if(a < b) + { + a = b; + changed = true; + } + } +}; + +template +constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation) +{ + T result = ck::type_convert(0.0f); + + if(operation == InMemoryDataOperationEnum::AtomicMax) + result = ck::NumericLimits::Lowest(); + + return (result); +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = false; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value; +}; + +} // namespace reduce +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/sequence.hpp b/3rdparty/composable_kernel/include/ck/utility/sequence.hpp new file mode 100644 index 0000000000000000000000000000000000000000..97b597221c2850d4b28b644266978fc56b295913 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/sequence.hpp @@ -0,0 +1,899 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +template +struct static_for; + +template +struct Sequence; + +template +struct sequence_split; + +template +struct sequence_reverse; + +template +struct sequence_map_inverse; + +template +struct is_valid_sequence_map; + +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence); + +template +__host__ __device__ constexpr auto sequence_pop_back(Seq); + +template +struct Sequence +{ + using Type = Sequence; + using data_type = index_t; + + static constexpr index_t mSize = sizeof...(Is); + + __host__ __device__ static constexpr auto Size() { return Number{}; } + + __host__ __device__ static constexpr auto GetSize() { return Size(); } + + __host__ __device__ static constexpr index_t At(index_t I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const index_t mData[mSize + 1] = {Is..., 0}; + return mData[I]; + } + + template + __host__ __device__ static constexpr auto At(Number) + { + static_assert(I < mSize, "wrong! I too large"); + + return Number{}; + } + + template + __host__ __device__ static constexpr auto Get(Number) + { + return At(Number{}); + } + + template + __host__ __device__ constexpr auto operator[](I i) const + { + return At(i); + } + + template + __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) + { + static_assert(sizeof...(Is) == sizeof...(IRs), + "wrong! reorder map should have the same size as Sequence to be rerodered"); + + static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); + + return Sequence{})...>{}; + } + + // MapOld2New is Sequence<...> + template + __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) + { + static_assert(MapOld2New::Size() == Size(), + "wrong! reorder map should have the same size as Sequence to be rerodered"); + + static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); + + return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); + } + + __host__ __device__ static constexpr auto Reverse() + { + return typename sequence_reverse::type{}; + } + + __host__ __device__ static constexpr auto Front() + { + static_assert(mSize > 0, "wrong!"); + return At(Number<0>{}); + } + + __host__ __device__ static constexpr auto Back() + { + static_assert(mSize > 0, "wrong!"); + return At(Number{}); + } + + __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } + + __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); } + + template + __host__ __device__ static constexpr auto PushFront(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushFront(Number...) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Number...) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto Extract(Number...) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto Extract(Sequence) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto Modify(Number, Number) + { + static_assert(I < Size(), "wrong!"); + + using seq_split = sequence_split; + constexpr auto seq_left = typename seq_split::left_type{}; + constexpr auto seq_right = typename seq_split::right_type{}.PopFront(); + + return seq_left.PushBack(Number{}).PushBack(seq_right); + } + + template + __host__ __device__ static constexpr auto Transform(F f) + { + return Sequence{}; + } + + __host__ __device__ static void Print() + { + printf("{"); + printf("size %d, ", index_t{Size()}); + static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); }); + printf("}"); + } +}; + +// merge sequence +template +struct sequence_merge +{ + using type = typename sequence_merge::type>::type; +}; + +template +struct sequence_merge, Sequence> +{ + using type = Sequence; +}; + +template +struct sequence_merge +{ + using type = Seq; +}; + +// generate sequence +template +struct sequence_gen +{ + template + struct sequence_gen_impl + { + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; + + using type = typename sequence_merge< + typename sequence_gen_impl::type, + typename sequence_gen_impl::type>::type; + }; + + template + struct sequence_gen_impl + { + static constexpr index_t Is = G{}(Number{}); + using type = Sequence; + }; + + template + struct sequence_gen_impl + { + using type = Sequence<>; + }; + + using type = typename sequence_gen_impl<0, NSize, F>::type; +}; + +// arithmetic sequence +template +struct arithmetic_sequence_gen +{ + struct F + { + __host__ __device__ constexpr index_t operator()(index_t i) const + { + return i * Increment + IBegin; + } + }; + + using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type1 = Sequence<>; + + static constexpr bool kHasContent = + (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); + + using type = typename conditional::type; +}; + +// uniform sequence +template +struct uniform_sequence_gen +{ + struct F + { + __host__ __device__ constexpr index_t operator()(index_t) const { return I; } + }; + + using type = typename sequence_gen::type; +}; + +// reverse inclusive scan (with init) sequence +template +struct sequence_reverse_inclusive_scan; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; + + static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); + + using type = typename sequence_merge, old_scan>::type; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = Sequence; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = Sequence<>; +}; + +// split sequence +template +struct sequence_split +{ + static constexpr index_t NSize = Seq{}.Size(); + + using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; + using range1 = typename arithmetic_sequence_gen::type; + + using left_type = decltype(Seq::Extract(range0{})); + using right_type = decltype(Seq::Extract(range1{})); +}; + +// reverse sequence +template +struct sequence_reverse +{ + static constexpr index_t NSize = Seq{}.Size(); + + using seq_split = sequence_split; + using type = typename sequence_merge< + typename sequence_reverse::type, + typename sequence_reverse::type>::type; +}; + +template +struct sequence_reverse> +{ + using type = Sequence; +}; + +template +struct sequence_reverse> +{ + using type = Sequence; +}; + +#if 1 +template +struct sequence_reduce +{ + using type = typename sequence_reduce::type>::type; +}; + +template +struct sequence_reduce, Sequence> +{ + using type = Sequence; +}; + +template +struct sequence_reduce +{ + using type = Seq; +}; +#endif + +template +struct sequence_sort_impl +{ + template + struct sorted_sequence_merge_impl + { + static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); + + static constexpr index_t chosen_value = + choose_left ? LeftValues::Front() : RightValues::Front(); + static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front(); + + using new_merged_values = decltype(MergedValues::PushBack(Number{})); + using new_merged_ids = decltype(MergedIds::PushBack(Number{})); + + using new_left_values = + typename conditional::type; + using new_left_ids = + typename conditional::type; + + using new_right_values = + typename conditional::type; + using new_right_ids = + typename conditional::type; + + using merge = sorted_sequence_merge_impl; + // this is output + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + RightValues, + RightIds, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge + { + using merge = sorted_sequence_merge_impl, + Sequence<>, + Comp>; + + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + static constexpr index_t nsize = Values::Size(); + + using split_unsorted_values = sequence_split; + using split_unsorted_ids = sequence_split; + + using left_unsorted_values = typename split_unsorted_values::left_type; + using left_unsorted_ids = typename split_unsorted_ids::left_type; + using left_sort = sequence_sort_impl; + using left_sorted_values = typename left_sort::sorted_values; + using left_sorted_ids = typename left_sort::sorted_ids; + + using right_unsorted_values = typename split_unsorted_values::right_type; + using right_unsorted_ids = typename split_unsorted_ids::right_type; + using right_sort = sequence_sort_impl; + using right_sorted_values = typename right_sort::sorted_values; + using right_sorted_ids = typename right_sort::sorted_ids; + + using merged_sorted = sorted_sequence_merge; + + using sorted_values = typename merged_sorted::merged_values; + using sorted_ids = typename merged_sorted::merged_ids; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + static constexpr bool choose_x = Compare{}(ValueX, ValueY); + + using sorted_values = + typename conditional, Sequence>::type; + using sorted_ids = typename conditional, Sequence>::type; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + using sorted_values = Sequence; + using sorted_ids = Sequence; +}; + +template +struct sequence_sort_impl, Sequence<>, Compare> +{ + using sorted_values = Sequence<>; + using sorted_ids = Sequence<>; +}; + +template +struct sequence_sort +{ + using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; + using sort = sequence_sort_impl; + + // this is output + using type = typename sort::sorted_values; + using sorted2unsorted_map = typename sort::sorted_ids; +}; + +template +struct sequence_unique_sort +{ + template + struct sorted_sequence_uniquify_impl + { + static constexpr index_t current_value = RemainValues::Front(); + static constexpr index_t current_id = RemainIds::Front(); + + static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); + + using new_remain_values = decltype(RemainValues::PopFront()); + using new_remain_ids = decltype(RemainIds::PopFront()); + + using new_uniquified_values = + typename conditional{})), + UniquifiedValues>::type; + + using new_uniquified_ids = + typename conditional{})), + UniquifiedIds>::type; + + using uniquify = sorted_sequence_uniquify_impl; + + // this is output + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + template + struct sorted_sequence_uniquify_impl, + Sequence<>, + UniquifiedValues, + UniquifiedIds, + Eq> + { + using uniquified_values = UniquifiedValues; + using uniquified_ids = UniquifiedIds; + }; + + template + struct sorted_sequence_uniquify + { + using uniquify = sorted_sequence_uniquify_impl, + Sequence, + Eq>; + + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + using sort = sequence_sort; + using sorted_values = typename sort::type; + using sorted_ids = typename sort::sorted2unsorted_map; + + using uniquify = sorted_sequence_uniquify; + + // this is output + using type = typename uniquify::uniquified_values; + using sorted2unsorted_map = typename uniquify::uniquified_ids; +}; + +template +struct is_valid_sequence_map : is_same::type, + typename sequence_sort>::type> +{ +}; + +template +struct sequence_map_inverse +{ + template + struct sequence_map_inverse_impl + { + static constexpr auto new_y2x = + WorkingY2X::Modify(X2Y::At(Number{}), Number{}); + + using type = + typename sequence_map_inverse_impl:: + type; + }; + + template + struct sequence_map_inverse_impl + { + using type = WorkingY2X; + }; + + using type = + typename sequence_map_inverse_impl::type, + 0, + SeqMap::Size()>::type; +}; + +template +__host__ __device__ constexpr bool operator==(Sequence, Sequence) +{ + return ((Xs == Ys) && ...); +} + +template +__host__ __device__ constexpr auto operator+(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs + Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs - Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs * Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs / Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs % Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator+(Sequence, Number) +{ + return Sequence<(Xs + Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Sequence, Number) +{ + return Sequence<(Xs - Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Number) +{ + return Sequence<(Xs * Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Number) +{ + return Sequence<(Xs / Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Number) +{ + return Sequence<(Xs % Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator+(Number, Sequence) +{ + return Sequence<(Y + Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Number, Sequence) +{ + return Sequence<(Y - Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Number, Sequence) +{ + return Sequence<(Y * Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Number, Sequence) +{ + return Sequence<(Y / Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Number, Sequence) +{ + return Sequence<(Y % Xs)...>{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Seq) +{ + static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!"); + return sequence_pop_front(Seq::Reverse()).Reverse(); +} + +template +__host__ __device__ constexpr auto merge_sequences(Seqs...) +{ + return typename sequence_merge::type{}; +} + +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) +{ + static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); + + return Sequence{}; +} + +template +__host__ __device__ constexpr auto +transform_sequences(F f, Sequence, Sequence, Sequence) +{ + static_assert(Sequence::mSize == Sequence::mSize && + Sequence::mSize == Sequence::mSize, + "Dim not the same"); + + return Sequence{}; +} + +template +__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) +{ + return typename sequence_reverse_inclusive_scan::type{}; +} + +template +__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number) +{ + return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number{}) + .PushBack(Number{}); +} + +template +__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) +{ + return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); +} + +template +__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence /* ids */) +{ + return Sequence{})...>{}; +} + +#if 1 +namespace detail { +template +struct pick_sequence_elements_by_mask_impl +{ + using new_work_seq = typename conditional::type; + + using type = + typename pick_sequence_elements_by_mask_impl::type; +}; + +template +struct pick_sequence_elements_by_mask_impl, Sequence<>> +{ + using type = WorkSeq; +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask) +{ + static_assert(Seq::Size() == Mask::Size(), "wrong!"); + + return typename detail::pick_sequence_elements_by_mask_impl, Seq, Mask>::type{}; +} + +namespace detail { +template +struct modify_sequence_elements_by_ids_impl +{ + using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front())); + + using type = + typename modify_sequence_elements_by_ids_impl::type; +}; + +template +struct modify_sequence_elements_by_ids_impl, Sequence<>> +{ + using type = WorkSeq; +}; +} // namespace detail + +template +__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids) +{ + static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!"); + + return typename detail::modify_sequence_elements_by_ids_impl::type{}; +} +#endif + +template +__host__ __device__ constexpr index_t +reduce_on_sequence(Seq, Reduce f, Number /*initial_value*/) +{ + index_t result = Init; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + result = f(result, Seq::At(i)); + } + + return result; +} + +// TODO: a generic any_of for any container +template +__host__ __device__ constexpr bool sequence_any_of(Seq, F f) +{ + bool flag = false; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + flag = flag || f(Seq::At(i)); + } + + return flag; +} + +// TODO: a generic all_of for any container +template +__host__ __device__ constexpr bool sequence_all_of(Seq, F f) +{ + bool flag = true; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + flag = flag && f(Seq::At(i)); + } + + return flag; +} + +template +using sequence_merge_t = typename sequence_merge::type; + +template +using uniform_sequence_gen_t = typename uniform_sequence_gen::type; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/sequence_helper.hpp b/3rdparty/composable_kernel/include/ck/utility/sequence_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..db25c27e70c3653f94524367b8a6bce79480113e --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/sequence_helper.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/tuple.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_sequence(Number...) +{ + return Sequence{}; +} + +// F returns index_t +template +__host__ __device__ constexpr auto generate_sequence(F, Number) +{ + return typename sequence_gen::type{}; +} + +// F returns Number<> +template +__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +template +__host__ __device__ constexpr auto to_sequence(Tuple...>) +{ + return Sequence{}; +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/span.hpp b/3rdparty/composable_kernel/include/ck/utility/span.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1e501214547cfbcec25921d6526e77563504835a --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/span.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +namespace ck { + +template +class span +{ + public: + using element_type = T; + using value_type = std::remove_cv_t; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = element_type*; + using const_pointer = const element_type*; + using reference = element_type&; + using const_reference = const element_type&; + using iterator = pointer; + using const_iterator = pointer; + + constexpr span() : span(nullptr, size_type{0}) {} + + constexpr span(pointer first, size_type count) : ptr_(first), size_(count) {} + + constexpr span(pointer first, pointer last) : span(first, last - first) {} + + template + constexpr span(element_type (&arr)[N]) noexcept : span(arr, N) + { + } + + template + constexpr span(std::array& arr) noexcept : span(arr.data(), N) + { + } + + template + constexpr span(const Container& container) : span(container.data(), container.size()) + { + } + + constexpr iterator begin() const noexcept { return ptr_; } + constexpr const_iterator cbegin() const noexcept { return begin(); } + + constexpr iterator end() const noexcept { return begin() + size(); } + constexpr const_iterator cend() const noexcept { return end(); } + + constexpr reference front() const { return *begin(); } + constexpr reference back() const { return *(--end()); } + + constexpr reference operator[](size_type idx) const { return *(begin() + idx); } + constexpr pointer data() const noexcept { return ptr_; } + + constexpr size_type size() const noexcept { return size_; } + + private: + pointer ptr_; + size_type size_; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/static_buffer.hpp b/3rdparty/composable_kernel/include/ck/utility/static_buffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd25c96203288cfc9aa0959e9a7b1d83ef4abfc2 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/static_buffer.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "statically_indexed_array.hpp" + +namespace ck { + +// static buffer for scalar +template // TODO remove this bool, no longer needed +struct StaticBuffer : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + __host__ __device__ constexpr StaticBuffer() : base{} {} + + template + __host__ __device__ constexpr StaticBuffer& operator=(const Tuple& y) + { + static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same"); + StaticBuffer& x = *this; + static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; }); + return x; + } + + __host__ __device__ constexpr StaticBuffer& operator=(const T& y) + { + StaticBuffer& x = *this; + static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; }); + return x; + } + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } + + // read access + template + __host__ __device__ constexpr const T& operator[](Number i) const + { + return base::operator[](i); + } + + // write access + template + __host__ __device__ constexpr T& operator()(Number i) + { + return base::operator()(i); + } + + __host__ __device__ void Set(T x) + { + static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; }); + } + + __host__ __device__ void Clear() { Set(T{0}); } +}; + +// static buffer for vector +template ::value, bool>::type = false> +struct StaticBufferTupleOfVector + : public StaticallyIndexedArray, NumOfVector> +{ + using V = typename vector_type::type; + using base = StaticallyIndexedArray, NumOfVector>; + + static constexpr auto s_per_v = Number{}; + static constexpr auto num_of_v_ = Number{}; + static constexpr auto s_per_buf = s_per_v * num_of_v_; + + __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {} + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } + + __host__ __device__ static constexpr index_t Size() { return s_per_buf; }; + + // Get S + // i is offset of S + template + __host__ __device__ constexpr const S& operator[](Number i) const + { + constexpr auto i_v = i / s_per_v; + constexpr auto i_s = i % s_per_v; + + return base::operator[](i_v).template AsType()[i_s]; + } + + // Set S + // i is offset of S + template + __host__ __device__ constexpr S& operator()(Number i) + { + constexpr auto i_v = i / s_per_v; + constexpr auto i_s = i % s_per_v; + + return base::operator()(i_v).template AsType()(i_s); + } + + // Get X + // i is offset of S, not X. i should be aligned to X + template ::value, bool>::type = false> + __host__ __device__ constexpr auto GetAsType(Number i) const + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + return base::operator[](i_v).template AsType()[i_x]; + } + + // Set X + // i is offset of S, not X. i should be aligned to X + template ::value, bool>::type = false> + __host__ __device__ constexpr void SetAsType(Number i, X x) + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + base::operator()(i_v).template AsType()(i_x) = x; + } + + // Get read access to vector_type V + // i is offset of S, not V. i should be aligned to V + template + __host__ __device__ constexpr const auto& GetVectorTypeReference(Number i) const + { + static_assert(i % s_per_v == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + + return base::operator[](i_v); + } + + // Get write access to vector_type V + // i is offset of S, not V. i should be aligned to V + template + __host__ __device__ constexpr auto& GetVectorTypeReference(Number i) + { + static_assert(i % s_per_v == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + + return base::operator()(i_v); + } + + __host__ __device__ void Clear() + { + constexpr index_t NumScalars = NumOfVector * ScalarPerVector; + + static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); }); + } +}; + +template +__host__ __device__ constexpr auto make_static_buffer(Number) +{ + return StaticBuffer{}; +} + +template +__host__ __device__ constexpr auto make_static_buffer(LongNumber) +{ + return StaticBuffer{}; +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/statically_indexed_array.hpp b/3rdparty/composable_kernel/include/ck/utility/statically_indexed_array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3438776f413cf010664cc8aa4a18c09bf161fae7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/statically_indexed_array.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP +#define CK_STATICALLY_INDEXED_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +namespace detail { +template +struct tuple_concat; + +template +struct tuple_concat, Tuple> +{ + using type = Tuple; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = + typename tuple_concat::type, + typename StaticallyIndexedArrayImpl::type>::type; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = Tuple<>; +}; + +template +struct StaticallyIndexedArrayImpl +{ + using type = Tuple; +}; +} // namespace detail + +template +using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl::type; + +template +__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) +{ + return StaticallyIndexedArray(x, static_cast(xs)...); +} + +// make empty StaticallyIndexedArray +template +__host__ __device__ constexpr auto make_statically_indexed_array() +{ + return StaticallyIndexedArray(); +} + +template +struct StaticallyIndexedArray_v2 +{ + __host__ __device__ constexpr StaticallyIndexedArray_v2() = default; + + __host__ __device__ static constexpr index_t Size() { return N; } + + // read access + template + __host__ __device__ constexpr const auto& At(Number) const + { + static_assert(I < N, "wrong! out of range"); + + return data_[I]; + } + + // write access + template + __host__ __device__ constexpr auto& At(Number) + { + static_assert(I < N, "wrong! out of range"); + + return data_[I]; + } + + // read access + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + // write access + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + T data_[N]; +}; + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/statically_indexed_array_multi_index.hpp b/3rdparty/composable_kernel/include/ck/utility/statically_indexed_array_multi_index.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21b2941b21401ac561932a45ec304536365f2927 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP +#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = StaticallyIndexedArray; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_statically_indexed_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +// Here should use MultiIndex, instead of Tuple, although the former +// is the alias of the latter. This is because compiler cannot infer the NSize if +// using MultiIndex +// TODO: how to fix this? +template < + typename... Ys, + typename X, + enable_if_t::value && !std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator+=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template < + typename... Ys, + typename X, + enable_if_t::value && !std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator-=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template < + typename... Xs, + typename Y, + enable_if_t::value && !std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator+(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] + y[i]; }); + return r; +} + +template < + typename... Xs, + typename Y, + enable_if_t::value && !std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator-(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] - y[i]; }); + return r; +} + +template < + typename... Xs, + typename Y, + enable_if_t::value && !std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator*(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y[i]; }); + return r; +} + +// MultiIndex = scalar * MultiIndex +template ::value || std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator*(Y a, const Tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a * x[i]; }); + return r; +} + +// MultiIndex = MultiIndex * scalar +template ::value || std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator*(const Tuple& x, Y a) +{ + return a * x; +} + +namespace mathext { + +template +__host__ __device__ constexpr auto exp(const Tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::exp(x[i]); }); + return r; +} + +template +__host__ __device__ constexpr auto max(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::max(x[i], y[i]); }); + return r; +} + +} // namespace mathext + +template +__host__ __device__ void print_multi_index(const Tuple& x) +{ + printf("{"); + printf("MultiIndex, "); + printf("size %d,", index_t{sizeof...(Xs)}); + static_for<0, sizeof...(Xs), 1>{}( + [&](auto i) { printf("%d ", static_cast(x.At(i))); }); + printf("}"); +} + +} // namespace ck +#endif diff --git a/3rdparty/composable_kernel/include/ck/utility/synchronization.hpp b/3rdparty/composable_kernel/include/ck/utility/synchronization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec8b6892378dff1d8b4017f9ff6957a411ccd8ae --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/synchronization.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { + +__device__ void block_sync_lds() +{ +#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +#else + __syncthreads(); +#endif +} +__device__ void s_nop() +{ +#if 1 + asm volatile("\ + s_nop 0 \n \ + " ::); +#else + __builtin_amdgcn_s_barrier(); +#endif +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/thread_group.hpp b/3rdparty/composable_kernel/include/ck/utility/thread_group.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d469dec899a556c5b8062efa775b7dc694500fe8 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/thread_group.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "get_id.hpp" + +namespace ck { + +template +struct ThisThreadBlock +{ + static constexpr index_t kNumThread_ = ThreadPerBlock; + + __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; } + + __device__ static constexpr bool IsBelong() { return true; } + + __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/transpose_vectors.hpp b/3rdparty/composable_kernel/include/ck/utility/transpose_vectors.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2b0075d6005e3c5cf7d6772fda8132a6404878ed --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/transpose_vectors.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "statically_indexed_array.hpp" +#include "data_type.hpp" + +namespace ck { + +template ::value, bool>::type = false> +struct transpose_vectors; + +// transpose fp16 2x2 +__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1) +{ +#if 0 + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + const vector_type vx0{x0}, vx1{x1}; + vector_type vy0, vy1; + + vy0.template AsType()(I0) = vx0.template AsType()[I0]; + vy0.template AsType()(I1) = vx1.template AsType()[I0]; + + vy1.template AsType()(I0) = vx0.template AsType()[I1]; + vy1.template AsType()(I1) = vx1.template AsType()[I1]; + + y0 = vy0.template AsType()[I0]; + y1 = vy1.template AsType()[I0]; +#else + constexpr int32_t m0 = 0x05040100; + constexpr int32_t m1 = 0x07060302; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + y0 = bit_cast(__builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m0)); + y1 = bit_cast(__builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m1)); +#endif +} + +template +struct transpose_vectors +{ + // we got [NY * NX] amount of S data to be transposed + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = half_t; + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); + + // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 2>{}([&](auto iy) { + static_for<0, NX, 2>{}([&](auto ix) { + // reference to 2 half2_t data from vx_tuple + const auto& x_s2_0 = vx_tuple[ix].template AsType()[iy / I2]; + const auto& x_s2_1 = vx_tuple[ix + I1].template AsType()[iy / I2]; + + // reference to 2 half2_t data from vy_tuple + auto& y_s2_0 = vy_tuple(iy).template AsType()(ix / I2); + auto& y_s2_1 = vy_tuple(iy + I1).template AsType()(ix / I2); + + // transpose + transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1); + }); + }); + } +}; + +// transpose int8 4x4 +__device__ void transpose_int8_4x4(const int8x4_t& x0, + const int8x4_t& x1, + const int8x4_t& x2, + const int8x4_t& x3, + int8x4_t& y0, + int8x4_t& y1, + int8x4_t& y2, + int8x4_t& y3) +{ + int32_t t0, t1; + int32_t z0, z1, z2, z3; + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + t0 = __builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m0); + t1 = __builtin_amdgcn_perm(bit_cast(x3), bit_cast(x2), m0); + z0 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m1); + z1 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m2); + t0 = __builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m3); + t1 = __builtin_amdgcn_perm(bit_cast(x3), bit_cast(x2), m3); + z2 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m1); + z3 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m2); + + y0 = bit_cast(z0); + y1 = bit_cast(z1); + y2 = bit_cast(z2); + y3 = bit_cast(z3); +} + +template +struct transpose_vectors +{ + // we got [NY * NX] amount of S data to be transposed + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = int8_t; + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + + // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // reference to 4 int8 data from vx_tuple + const auto& x_s4_0 = vx_tuple[ix].template AsType()[iy / I4]; + const auto& x_s4_1 = vx_tuple[ix + I1].template AsType()[iy / I4]; + const auto& x_s4_2 = vx_tuple[ix + I2].template AsType()[iy / I4]; + const auto& x_s4_3 = vx_tuple[ix + I3].template AsType()[iy / I4]; + + // reference to 4 int8 data from vy_tuple + auto& y_s4_0 = vy_tuple(iy).template AsType()(ix / I4); + auto& y_s4_1 = vy_tuple(iy + I1).template AsType()(ix / I4); + auto& y_s4_2 = vy_tuple(iy + I2).template AsType()(ix / I4); + auto& y_s4_3 = vy_tuple(iy + I3).template AsType()(ix / I4); + + // transpose + transpose_int8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3); + }); + }); + } +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/tuple.hpp b/3rdparty/composable_kernel/include/ck/utility/tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d8664be550b79bc4de2edd224bc63350a0a1bde7 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/tuple.hpp @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/enable_if.hpp" + +namespace ck { + +namespace detail { + +template +struct TupleElementKey +{ + __host__ __device__ constexpr TupleElementKey() = default; +}; + +template +struct TupleElementKeyData +{ + using DataType = Data; + +#if 0 // workaround compiler complaint about implicitly-deleted default constructor + __host__ __device__ constexpr TupleElementKeyData() = default; +#else + __host__ __device__ constexpr TupleElementKeyData() : mData{} {} +#endif + + template , TupleElementKeyData>::value, + bool>::type = false> + __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward(v)) + { + } + + DataType mData; +}; + +// for read access of tuple element +template +__host__ __device__ constexpr const Data& +get_tuple_element_data_reference(const TupleElementKeyData& x) +{ + return static_cast(x.mData); +} + +// for write access of tuple element +template +__host__ __device__ constexpr Data& +get_tuple_element_data_reference(TupleElementKeyData& x) +{ + return x.mData; +} + +// TODO: not sure the use of reference is correct +template +__host__ __device__ constexpr Data&& +get_tuple_element_data_reference(TupleElementKeyData&& x) +{ + return static_cast(x.mData); +} + +// for infering type of tuple element +template +__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData& x) +{ + return std::forward(x.mData); +} + +template +struct TupleImpl; + +template +struct TupleImpl, Xs...> : TupleElementKeyData, Xs>... +{ + __host__ __device__ constexpr TupleImpl() = default; + + template , TupleImpl>::value, + bool>::type = false> + __host__ __device__ constexpr TupleImpl(Y&& y) + : TupleElementKeyData, Xs>(std::forward(y))... + { + } + + template = 2, bool>::type = false> + __host__ __device__ constexpr TupleImpl(Ys&&... ys) + : TupleElementKeyData, Xs>(std::forward(ys))... + { + static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), + "wrong! inconsistent size"); + } + + __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } + + template + __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey) const + { + return get_tuple_element_data_reference>(*this); + } + + template + __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) + { + return get_tuple_element_data_reference>(*this); + } +}; + +} // namespace detail + +template +struct Tuple : detail::TupleImpl::type, Xs...> +{ + using base = + detail::TupleImpl::type, Xs...>; + + __host__ __device__ constexpr Tuple() = default; + + template , Tuple>::value, + bool>::type = false> + __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) + { + } + + template = 2, bool>::type = + false> + __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward(ys)...) + { + } + + __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } + + // read access + template + __host__ __device__ constexpr const auto& At(Number) const + { + static_assert(I < base::Size(), "wrong! out of range"); + return base::GetElementDataByKey(detail::TupleElementKey{}); + } + + // write access + template + __host__ __device__ constexpr auto& At(Number) + { + static_assert(I < base::Size(), "wrong! out of range"); + return base::GetElementDataByKey(detail::TupleElementKey{}); + } + + // read access + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + // write access + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; + +template <> +struct Tuple<> +{ + __host__ __device__ constexpr Tuple() = default; + + __host__ __device__ static constexpr index_t Size() { return 0; } + + template + __host__ __device__ constexpr auto operator=(const T&) + { + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; + +template +struct tuple_element +{ + // type should keep the cv/ref qualifier of original tuple element + using type = decltype(detail::get_tuple_element_data>(TTuple{})); +}; + +template +using tuple_element_t = typename tuple_element::type; + +template +__host__ __device__ constexpr auto make_tuple(Xs&&... xs) +{ + return Tuple...>(std::forward(xs)...); +} + +// https://en.cppreference.com/w/cpp/utility/tuple/tie +template +constexpr Tuple tie(Args&... args) noexcept +{ + return {args...}; +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/tuple_helper.hpp b/3rdparty/composable_kernel/include/ck/utility/tuple_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f5b142a5e7a765807bd9d7f556f7b8afc512d37 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/tuple_helper.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "functional4.hpp" +#include "tuple.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto generate_tuple(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +template +__host__ __device__ constexpr auto generate_tie(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return tie(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) +template +__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple& tx, + const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return Tuple{std::forward(zs)...}; }, + tx, + ty); +} + +namespace detail { + +template +__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence) +{ + return make_tuple(f(x.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}), z.At(Number{}))...); +} + +} // namespace detail + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x) +{ + return detail::transform_tuples_impl( + f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y) +{ + return detail::transform_tuples_impl( + f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) +{ + return detail::transform_tuples_impl( + f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/include/ck/utility/type.hpp b/3rdparty/composable_kernel/include/ck/utility/type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90b9df2950b7d979f8cef386d155c72f6d7a39a5 --- /dev/null +++ b/3rdparty/composable_kernel/include/ck/utility/type.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/enable_if.hpp" + +namespace ck { + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_same : public integral_constant +{ +}; + +template +inline constexpr bool is_same_v = is_same::value; + +template +using remove_reference_t = typename std::remove_reference::type; + +template +using remove_cv_t = typename std::remove_cv::type; + +template +using remove_cvref_t = remove_cv_t>; + +template +using remove_pointer_t = typename std::remove_pointer::type; + +template +inline constexpr bool is_pointer_v = std::is_pointer::value; + +template ::type = false> +__host__ __device__ constexpr Y bit_cast(const X& x) +{ +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST + Y y; + + __builtin_memcpy(&y, &x, sizeof(X)); + + return y; +#else + union AsType + { + X x; + Y y; + }; + + return AsType{x}.y; +#endif +} + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/CMakeLists.txt b/3rdparty/composable_kernel/library/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..90873fdd1489e4e09ec2d6d3ac5664180a4ff06e --- /dev/null +++ b/3rdparty/composable_kernel/library/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(src/tensor_operation_instance/gpu) +add_subdirectory(src/utility) diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..46a1fa559a1a3d164acc99e536f7e26a83b97ed4 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceBatchedGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g_m_k, + const Tensor& b_g_k_n, + Tensor& c_g_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g_m_k_{a_g_m_k}, + b_g_k_n_{b_g_k_n}, + c_g_m_n_{c_g_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g_m_k_; + const Tensor& b_g_k_n_; + Tensor& c_g_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm::Argument; + + float Run(const Argument& arg) + { + auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) { + const int K = arg.a_g_m_k_.mDesc.GetLengths()[2]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); + arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g_m_n_(g, m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_gmk_gkn_gmn, + arg.c_g_m_n_.mDesc.GetLengths()[0], + arg.c_g_m_n_.mDesc.GetLengths()[1], + arg.c_g_m_n_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g_m_k, + const Tensor& b_g_k_n, + Tensor& c_g_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b621e88a0c6e890bf1efb026615fcc63fadf6c9 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/utility/math_v2.hpp" +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + struct Argument : public device::BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array dxStrides, + const std::array dyStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const DyDataType* p_dy, + const ScaleDataType* p_scale, + const MeanVarDataType* p_savedMean, + const MeanVarDataType* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + DxDataType* p_dx, + DscaleDbiasDataType* p_dscale, + DscaleDbiasDataType* p_dbias) + : reduceDims_(reduceDims), + bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnDscaleDbiasStrides_(bnDscaleDbiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_dy_(p_dy), + p_scale_(p_scale), + p_savedMean_(p_savedMean), + p_savedInvVar_(p_savedInvVar), + dy_elementwise_op_(dy_elementwise_op), + p_dx_(p_dx), + p_dscale_(p_dscale), + p_dbias_(p_dbias) + { + using ck::host_common::get_index_set; + + if(std::any_of( + reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; })) + throw std::runtime_error("Invalid reduce dimensions!"); + + // get invariant_dims[] and invariant_lengths[] + for(int dim = 0, i = 0; dim < Rank; dim++) + if(std::none_of( + reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; })) + { + invariantDims_[i] = dim; + invariant_lengths_[i] = xyLengths[dim]; + i++; + }; + + // get reduce_lengths_[] + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims[j]; + reduce_lengths_[i++] = xyLengths[dim]; + }; + + for(int i = 0; i < NumInvariantDim; i++) + if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i]) + throw std::runtime_error("Invalid lengths parameters!"); + + for(int j = 0, i = 0; j < NumInvariantDim; j++) + { + int dim = invariantDims_[j]; + x_invariant_strides_[i] = xStrides[dim]; + dy_invariant_strides_[i] = dyStrides[dim]; + dx_invariant_strides_[i] = dxStrides[dim]; + i++; + }; + + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims_[j]; + x_reduce_strides_[i] = xStrides[dim]; + dy_reduce_strides_[i] = dyStrides[dim]; + dx_reduce_strides_[i] = dxStrides[dim]; + i++; + }; + + reduceSize_ = std::accumulate( + reduce_lengths_.begin(), reduce_lengths_.end(), 1, std::multiplies{}); + + invariant_index_set_ = get_index_set(invariant_lengths_); + reduce_index_set_ = get_index_set(reduce_lengths_); + + epsilon_ = type_convert(epsilon); + + haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr); + } + + std::array reduceDims_; + std::array invariantDims_; + std::array invariant_lengths_; + std::array reduce_lengths_; + + const std::array bnScaleBiasMeanVarLengths_; + const std::array bnScaleStrides_; + const std::array bnDscaleDbiasStrides_; + const std::array bnMeanVarStrides_; + + std::array x_invariant_strides_; + std::array dy_invariant_strides_; + std::array dx_invariant_strides_; + std::array x_reduce_strides_; + std::array dy_reduce_strides_; + std::array dx_reduce_strides_; + + const XDataType* p_x_; + const DyDataType* p_dy_; + const ScaleDataType* p_scale_; + const MeanVarDataType* p_savedMean_; + const MeanVarDataType* p_savedInvVar_; + const DyElementwiseOp dy_elementwise_op_; + + DxDataType* p_dx_; + DscaleDbiasDataType* p_dscale_; + DscaleDbiasDataType* p_dbias_; + + bool haveSavedMeanInvVar_; + + std::vector> invariant_index_set_; + std::vector> reduce_index_set_; + + AccDataType epsilon_; + size_t reduceSize_; + }; + + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + using ck::host_common::get_offset_from_index; + + auto thread_reduce_func = [&](auto invariant_index) { + size_t x_invariant_offset = get_offset_from_index( + arg.x_invariant_strides_, invariant_index); + size_t dy_invariant_offset = get_offset_from_index( + arg.dy_invariant_strides_, invariant_index); + size_t dx_invariant_offset = get_offset_from_index( + arg.dx_invariant_strides_, invariant_index); + + AccDataType mean = type_convert(0.0f); + AccDataType variance = type_convert(0.0f); + AccDataType invVar; + int32_t curr_count = 0; + + if(arg.haveSavedMeanInvVar_) + { + size_t mean_invVar_invariant_offset = get_offset_from_index( + arg.bnMeanVarStrides_, invariant_index); + + mean = + type_convert(arg.p_savedMean_[mean_invVar_invariant_offset]); + invVar = + type_convert(arg.p_savedInvVar_[mean_invVar_invariant_offset]); + } + else + { + // compute mean, variance using welford method + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + + curr_count++; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType delta = x - mean; + + mean += delta / curr_count; + + AccDataType delta2 = x - mean; + + variance += delta * delta2; + }; + + // actual variance + variance = variance / curr_count; + + // inv-variance defined as 1/sqrt(epsilon+variance) + invVar = + type_convert(1.0f) / ck::math::sqrt(arg.epsilon_ + variance); + }; + + AccDataType dbias = + type_convert(0.0f); // Sum on reduced dimensions of dy + AccDataType dscale = + type_convert(0.0f); // Sum on reduced dimensions of dy * norm_x + + // 1) calculate dy * (x - mean) * inv-variance + // 2) calculate sum(dy) on reduced dimensions + // 3) calculate sum(dy * norm_x) on reduced dimensions + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + size_t dy_reduce_offset = get_offset_from_index( + arg.dy_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + auto dy_offset = dy_invariant_offset + dy_reduce_offset; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType norm_x = (x - mean) * invVar; + AccDataType dy = type_convert(arg.p_dy_[dy_offset]); + + arg.dy_elementwise_op_(dy, dy); + + dbias += dy; + dscale += norm_x * dy; + }; + + size_t dscale_offset = get_offset_from_index( + arg.bnDscaleDbiasStrides_, invariant_index); + size_t dbias_offset = get_offset_from_index( + arg.bnDscaleDbiasStrides_, invariant_index); + + arg.p_dscale_[dscale_offset] = type_convert(dscale); + arg.p_dbias_[dbias_offset] = type_convert(dbias); + + size_t scale_offset = + get_offset_from_index(arg.bnScaleStrides_, invariant_index); + + AccDataType scale = type_convert(arg.p_scale_[scale_offset]); + + AccDataType multiplier = type_convert(1.0f) / + type_convert(arg.reduceSize_) * invVar * + scale; + + // 1) calculate tmp = dscale * (x - mean) * inv-variance + // 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias + // - tmp) + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + size_t dy_reduce_offset = get_offset_from_index( + arg.dy_reduce_strides_, reduce_index); + size_t dx_reduce_offset = get_offset_from_index( + arg.dx_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + auto dy_offset = dy_invariant_offset + dy_reduce_offset; + auto dx_offset = dx_invariant_offset + dx_reduce_offset; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType norm_x = (x - mean) * invVar; + AccDataType dy = type_convert(arg.p_dy_[dy_offset]); + + arg.dy_elementwise_op_(dy, dy); + + AccDataType tmpVal = norm_x * dscale; + + AccDataType dx = multiplier * (type_convert(arg.reduceSize_) * dy - + dbias - tmpVal); + + arg.p_dx_[dx_offset] = type_convert(dx); + }; + }; + + std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t work_per_thread = + (arg.invariant_index_set_.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t i_begin = it * work_per_thread; + std::size_t i_end = std::min(static_cast((it + 1) * work_per_thread), + arg.invariant_index_set_.size()); + + auto f = [=] { + for(std::size_t i = i_begin; i < i_end; ++i) + { + thread_reduce_func(arg.invariant_index_set_[i]); + } + }; + + threads[it] = joinable_thread(f); + } + + return (0.0f); + }; + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + }; + }; + + bool IsSupportedArgument(const device::BaseArgument* p_arg) override + { + (void)p_arg; + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array dxStrides, + const std::array dyStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_dy, + const void* p_scale, + const void* p_savedMean, + const void* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + void* p_dx, + void* p_dscale, + void* p_dbias) override + { + return std::make_unique(xyLengths, + xStrides, + dxStrides, + dyStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnDscaleDbiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_dy), + static_cast(p_scale), + static_cast(p_savedMean), + static_cast(p_savedInvVar), + epsilon, + dy_elementwise_op, + static_cast(p_dx), + static_cast(p_dscale), + static_cast(p_dbias)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "Reference_BatchNorm_Backward" << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd0db316804fd191ee73d9952f6df6779420910d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/utility/math_v2.hpp" +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceBatchNormFwd : public device::DeviceBatchNormFwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + struct Argument : public device::BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* bnScale, + const BiasDataType* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) + : reduceDims_(reduceDims), + bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + bnScale_(bnScale), + bnBias_(bnBias), + y_elementwise_op_(y_elementwise_op), + p_y_(p_y), + resultSaveMean_(resultSaveMean), + resultSaveInvVariance_(resultSaveInvVariance), + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) + { + using ck::host_common::get_index_set; + + if(std::any_of( + reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; })) + throw std::runtime_error("Invalid reduce dimensions!"); + + // get invariant_dims[] and invariant_lengths[] + for(int dim = 0, i = 0; dim < Rank; dim++) + if(std::none_of( + reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; })) + { + invariantDims_[i] = dim; + invariant_lengths_[i] = xyLengths[dim]; + i++; + }; + + // get reduce_lengths_[] + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims[j]; + reduce_lengths_[i++] = xyLengths[dim]; + }; + + for(int i = 0; i < NumInvariantDim; i++) + if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i]) + throw std::runtime_error("Invalid lengths parameters!"); + + for(int j = 0, i = 0; j < NumInvariantDim; j++) + { + int dim = invariantDims_[j]; + x_invariant_strides_[i] = xStrides[dim]; + y_invariant_strides_[i] = yStrides[dim]; + i++; + }; + + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims_[j]; + x_reduce_strides_[i] = xStrides[dim]; + y_reduce_strides_[i] = yStrides[dim]; + i++; + }; + + invariant_index_set_ = get_index_set(invariant_lengths_); + reduce_index_set_ = get_index_set(reduce_lengths_); + + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + + resultSave = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr); + resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr); + } + + std::array reduceDims_; + std::array invariantDims_; + std::array invariant_lengths_; + std::array reduce_lengths_; + + const std::array bnScaleBiasMeanVarLengths_; + const std::array bnScaleStrides_; + const std::array bnBiasStrides_; + const std::array bnMeanVarStrides_; + + std::array x_invariant_strides_; + std::array y_invariant_strides_; + std::array x_reduce_strides_; + std::array y_reduce_strides_; + + const XDataType* p_x_; + const ScaleDataType* bnScale_; + const BiasDataType* bnBias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; + + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; + + bool resultSave, resultRunning; + + std::vector> invariant_index_set_; + std::vector> reduce_index_set_; + + AccDataType averageFactor_; + AccDataType epsilon_; + }; + + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + using ck::host_common::get_offset_from_index; + + auto thread_reduce_func = [&](auto invariant_index) { + size_t x_invariant_offset = get_offset_from_index( + arg.x_invariant_strides_, invariant_index); + size_t y_invariant_offset = get_offset_from_index( + arg.y_invariant_strides_, invariant_index); + AccDataType mean = type_convert(0.0f); + AccDataType variance = type_convert(0.0f); + int32_t curr_count = 0; + + // compute mean, variance using welford method + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + + curr_count++; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType delta = x - mean; + + mean += delta / curr_count; + + AccDataType delta2 = x - mean; + + variance += delta * delta2; + }; + + // actual variance + variance = variance / curr_count; + + // inv-variance defined as 1/sqrt(epsilon+variance) + AccDataType invVariance = + type_convert(1.0f) / ck::math::sqrt(arg.epsilon_ + variance); + + // save the mean/inv-variance if required + if(arg.resultSave) + { + size_t offset = get_offset_from_index(arg.bnMeanVarStrides_, + invariant_index); + + arg.resultSaveMean_[offset] = type_convert(mean); + arg.resultSaveInvVariance_[offset] = type_convert(invVariance); + }; + + // update the moving average if required + if(arg.resultRunning) + { + size_t offset = get_offset_from_index(arg.bnMeanVarStrides_, + invariant_index); + + AccDataType oneMinusAverageFactor = + type_convert(1.0) - arg.averageFactor_; + arg.resultRunningMean_[offset] = type_convert( + type_convert(arg.resultRunningMean_[offset]) * + oneMinusAverageFactor + + mean * arg.averageFactor_); + arg.resultRunningVariance_[offset] = type_convert( + arg.resultRunningVariance_[offset] * oneMinusAverageFactor + + variance * arg.averageFactor_); + }; + + size_t scale_offset = + get_offset_from_index(arg.bnScaleStrides_, invariant_index); + size_t bias_offset = + get_offset_from_index(arg.bnBiasStrides_, invariant_index); + + AccDataType scale = type_convert(arg.bnScale_[scale_offset]); + AccDataType bias = type_convert(arg.bnBias_[bias_offset]); + + // Normalization + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + size_t y_reduce_offset = get_offset_from_index( + arg.y_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + auto y_offset = y_invariant_offset + y_reduce_offset; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType norm_x = (x - mean) * invVariance; + + AccDataType y = scale * norm_x + bias; + + arg.y_elementwise_op_(y, y); + + arg.p_y_[y_offset] = type_convert(y); + }; + }; + + std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t work_per_thread = + (arg.invariant_index_set_.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t i_begin = it * work_per_thread; + std::size_t i_end = std::min(static_cast((it + 1) * work_per_thread), + arg.invariant_index_set_.size()); + + auto f = [=] { + for(std::size_t i = i_begin; i < i_end; ++i) + { + thread_reduce_func(arg.invariant_index_set_[i]); + } + }; + + threads[it] = joinable_thread(f); + } + + return (0.0f); + }; + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + }; + }; + + bool IsSupportedArgument(const device::BaseArgument* p_arg) override + { + (void)p_arg; + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* bnScale, + const void* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(bnScale), + static_cast(bnBias), + epsilon, + y_elementwise_op, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "Reference_BatchNorm_Forward" << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..463c655ac1d1414f80f9d8b7fcfc696be6326018 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp @@ -0,0 +1,300 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceBatchNormInfer : public device::DeviceBatchNormInfer +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + struct Argument : public device::BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* bnScale, + const BiasDataType* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + const MeanVarDataType* estimatedMean, + const MeanVarDataType* estimatedVariance, + YDataType* p_y) + : reduceDims_(reduceDims), + bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + bnScale_(bnScale), + bnBias_(bnBias), + y_elementwise_op_(y_elementwise_op), + estimatedMean_(estimatedMean), + estimatedVariance_(estimatedVariance), + p_y_(p_y) + { + using ck::host_common::get_index_set; + + if(std::any_of( + reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; })) + throw std::runtime_error("Invalid reduce dimensions!"); + + // get invariant_dims[] and invariant_lengths[] + for(int dim = 0, i = 0; dim < Rank; dim++) + if(std::none_of( + reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; })) + { + invariantDims_[i] = dim; + invariant_lengths_[i] = xyLengths[dim]; + i++; + }; + + // get reduce_lengths_[] + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims[j]; + reduce_lengths_[i++] = xyLengths[dim]; + }; + + // check invariant_lengths_ and bnScaleBiasMeanVarLengths + for(int i = 0; i < NumInvariantDim; i++) + if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i]) + throw std::runtime_error("Invalid lengths parameters!"); + + for(int j = 0, i = 0; j < NumInvariantDim; j++) + { + int dim = invariantDims_[j]; + x_invariant_strides_[i] = xStrides[dim]; + y_invariant_strides_[i] = yStrides[dim]; + i++; + }; + + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims_[j]; + x_reduce_strides_[i] = xStrides[dim]; + y_reduce_strides_[i] = yStrides[dim]; + i++; + }; + + invariant_index_set_ = get_index_set(invariant_lengths_); + reduce_index_set_ = get_index_set(reduce_lengths_); + + epsilon_ = type_convert(epsilon); + } + + std::array reduceDims_; + std::array invariantDims_; + std::array invariant_lengths_; + std::array reduce_lengths_; + + const std::array bnScaleBiasMeanVarLengths_; + const std::array bnScaleStrides_; + const std::array bnBiasStrides_; + const std::array bnMeanVarStrides_; + + std::array x_invariant_strides_; + std::array y_invariant_strides_; + std::array x_reduce_strides_; + std::array y_reduce_strides_; + + const XDataType* p_x_; + const ScaleDataType* bnScale_; + const BiasDataType* bnBias_; + const YElementwiseOp y_elementwise_op_; + + const MeanVarDataType* estimatedMean_; + const MeanVarDataType* estimatedVariance_; + + YDataType* p_y_; + + std::vector> invariant_index_set_; + std::vector> reduce_index_set_; + + AccDataType epsilon_; + }; + + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + using ck::host_common::get_offset_from_index; + + auto thread_reduce_func = [&](auto invariant_index) { + size_t x_invariant_offset = get_offset_from_index( + arg.x_invariant_strides_, invariant_index); + size_t y_invariant_offset = get_offset_from_index( + arg.y_invariant_strides_, invariant_index); + + size_t mean_variance_offset = + get_offset_from_index(arg.bnMeanVarStrides_, invariant_index); + + AccDataType mean = arg.estimatedMean_[mean_variance_offset]; + AccDataType variance = arg.estimatedVariance_[mean_variance_offset]; + + // inv-variance defined as 1/sqrt(epsilon+variance) + AccDataType invVariance = + type_convert(1.0f) / std::sqrt(arg.epsilon_ + variance); + + size_t scale_offset = + get_offset_from_index(arg.bnScaleStrides_, invariant_index); + size_t bias_offset = + get_offset_from_index(arg.bnBiasStrides_, invariant_index); + + AccDataType scale = type_convert(arg.bnScale_[scale_offset]); + AccDataType bias = type_convert(arg.bnBias_[bias_offset]); + + // normalization + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + size_t y_reduce_offset = get_offset_from_index( + arg.y_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + auto y_offset = y_invariant_offset + y_reduce_offset; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType norm_x = (x - mean) * invVariance; + + AccDataType y = scale * norm_x + bias; + + arg.y_elementwise_op_(y, y); + + arg.p_y_[y_offset] = type_convert(y); + }; + }; + + std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t work_per_thread = + (arg.invariant_index_set_.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t i_begin = it * work_per_thread; + std::size_t i_end = std::min(static_cast((it + 1) * work_per_thread), + arg.invariant_index_set_.size()); + + auto f = [=] { + for(std::size_t i = i_begin; i < i_end; ++i) + { + thread_reduce_func(arg.invariant_index_set_[i]); + } + }; + + threads[it] = joinable_thread(f); + } + + return (0.0f); + }; + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + }; + }; + + bool IsSupportedArgument(const device::BaseArgument* p_arg) override + { + (void)p_arg; + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* bnScale, + const void* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + const void* estimatedMean, + const void* estimatedVariance, + void* p_y) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(bnScale), + static_cast(bnBias), + epsilon, + y_elementwise_op, + static_cast(estimatedMean), + static_cast(estimatedVariance), + static_cast(p_y)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "Reference_BatchNorm_Infer<" << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b0149d88fdb062806d3b34b400f9b6d3389419f9 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/library/utility/host_tensor.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// FIXME: support arbitrary elementwise operation for A/B/C +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct ReferenceCGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k_real, + const Tensor& a_m_k_imag, + const Tensor& b_k_n_real, + const Tensor& b_k_n_imag, + Tensor& c_m_n_real, + Tensor& c_m_n_imag, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_real_{a_m_k_real}, + a_m_k_imag_{a_m_k_imag}, + b_k_n_real_{b_k_n_real}, + b_k_n_imag_{b_k_n_imag}, + c_m_n_real_{c_m_n_real}, + c_m_n_imag_{c_m_n_imag}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_real_; + const Tensor& a_m_k_imag_; + const Tensor& b_k_n_real_; + const Tensor& b_k_n_imag_; + Tensor& c_m_n_real_; + Tensor& c_m_n_imag_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceCGemm::Argument; + + float Run(const Argument& arg) + { + const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1]; + + if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1]) + { + throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM"); + } + + auto f_mk_kn_mn_real = [&](auto m, auto n) { + float v_c_real = 0; + + for(std::size_t k = 0; k < K; ++k) + { + float v_a_real = ck::type_convert(arg.a_m_k_real_(m, k)); + float v_a_imag = ck::type_convert(arg.a_m_k_imag_(m, k)); + float v_b_real = ck::type_convert(arg.b_k_n_real_(k, n)); + float v_b_imag = ck::type_convert(arg.b_k_n_imag_(k, n)); + + v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag; + } + + arg.c_m_n_real_(m, n) = ck::type_convert(v_c_real); + }; + + auto f_mk_kn_mn_imag = [&](auto m, auto n) { + float v_c_imag = 0; + + for(std::size_t k = 0; k < K; ++k) + { + float v_a_real = ck::type_convert(arg.a_m_k_real_(m, k)); + float v_a_imag = ck::type_convert(arg.a_m_k_imag_(m, k)); + float v_b_real = ck::type_convert(arg.b_k_n_real_(k, n)); + float v_b_imag = ck::type_convert(arg.b_k_n_imag_(k, n)); + + v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real; + } + + arg.c_m_n_imag_(m, n) = ck::type_convert(v_c_imag); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn_real, + arg.c_m_n_real_.mDesc.GetLengths()[0], + arg.c_m_n_real_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_mk_kn_mn_imag, + arg.c_m_n_imag_.mDesc.GetLengths()[0], + arg.c_m_n_imag_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k_real, + const Tensor& a_m_k_imag, + const Tensor& b_k_n_real, + const Tensor& b_k_n_imag, + Tensor& c_m_n_real, + Tensor& c_m_n_imag, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k_real, + a_m_k_imag, + b_k_n_real, + b_k_n_imag, + c_m_n_real, + c_m_n_imag, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceCGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..225f7b7e36f3a5a0d70058ed7089fc693a9e0493 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -0,0 +1,378 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// input descriptor in [G, N, C, Do, Ho, Wo] order +// weight descriptor in [G, K, C, Z, Y, X] order +// output descriptor in [G, N, K, Di, Hi, Wi] order +// phyiscal layout is irrelavent +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct ReferenceConvBwdData : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{input}, + weight_{weight}, + output_{output}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + Tensor& input_; + const Tensor& weight_; + const Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvBwdData::Argument; + + float Run(const Argument& arg) + { + if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 && + arg.weight_.GetNumOfDimension() == NDimSpatial + 3 && + arg.output_.GetNumOfDimension() == NDimSpatial + 3)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto f_ncw = [&](auto g, auto n, auto c, auto wi) { + std::size_t K = arg.weight_.GetLengths()[1]; + std::size_t X = arg.weight_.GetLengths()[3]; + std::size_t Wo = arg.output_.GetLengths()[3]; + + float v_acc = 0; + + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast(arg.in_left_pads_[0]) - + static_cast(x * arg.conv_dilations_[0]); + + if(w_tmp % arg.conv_strides_[0] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(arg.conv_strides_[0]); + + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + float v_out = 0; + float v_wei = 0; + + arg.out_element_op_( + v_out, ck::type_convert(arg.output_(g, n, k, wo))); + + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(g, k, c, x))); + + v_acc += v_out * v_wei; + } + } + } + } + + float v_in; + + arg.in_element_op_(v_in, v_acc); + + arg.input_(g, n, c, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_ncw, + arg.input_.GetLengths()[0], + arg.input_.GetLengths()[1], + arg.input_.GetLengths()[2], + arg.input_.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 2) + { + auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) { + std::size_t K = arg.weight_.GetLengths()[1]; + std::size_t Y = arg.weight_.GetLengths()[3]; + std::size_t X = arg.weight_.GetLengths()[4]; + + std::size_t Ho = arg.output_.GetLengths()[3]; + std::size_t Wo = arg.output_.GetLengths()[4]; + + float v_acc = 0; + + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = static_cast(hi) + + static_cast(arg.in_left_pads_[0]) - + static_cast(y * arg.conv_dilations_[0]); + if(h_tmp % arg.conv_strides_[0] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(arg.conv_strides_[0]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + static_cast(wi) + + static_cast(arg.in_left_pads_[1]) - + static_cast(x * arg.conv_dilations_[1]); + if(w_tmp % arg.conv_strides_[1] == 0) + { + auto wo = + static_cast(w_tmp) / + static_cast(arg.conv_strides_[1]); + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + float v_out = 0; + float v_wei = 0; + + arg.out_element_op_( + v_out, + ck::type_convert( + arg.output_(g, n, k, ho, wo))); + + arg.wei_element_op_( + v_wei, + ck::type_convert( + arg.weight_(g, k, c, y, x))); + + v_acc += v_out * v_wei; + } + } + } + } + } + } + } + + float v_in; + + arg.in_element_op_(v_in, v_acc); + + arg.input_(g, n, c, hi, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.input_.GetLengths()[0], + arg.input_.GetLengths()[1], + arg.input_.GetLengths()[2], + arg.input_.GetLengths()[3], + arg.input_.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 3) + { + auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) { + std::size_t K = arg.weight_.GetLengths()[1]; + std::size_t Z = arg.weight_.GetLengths()[3]; + std::size_t Y = arg.weight_.GetLengths()[4]; + std::size_t X = arg.weight_.GetLengths()[5]; + + std::size_t Do = arg.output_.GetLengths()[3]; + std::size_t Ho = arg.output_.GetLengths()[4]; + std::size_t Wo = arg.output_.GetLengths()[5]; + + float v_acc = 0; + + for(std::size_t z = 0; z < Z; ++z) + { + auto d_tmp = static_cast(di) + + static_cast(arg.in_left_pads_[0]) - + static_cast(z * arg.conv_dilations_[0]); + if(d_tmp % arg.conv_strides_[0] == 0) + { + auto do_ = static_cast(d_tmp) / + static_cast(arg.conv_strides_[0]); + if(do_ >= 0 && ck::type_convert(do_) < Do) + { + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = + static_cast(hi) + + static_cast(arg.in_left_pads_[1]) - + static_cast(y * arg.conv_dilations_[1]); + if(h_tmp % arg.conv_strides_[1] == 0) + { + auto ho = + static_cast(h_tmp) / + static_cast(arg.conv_strides_[1]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast( + arg.in_left_pads_[2]) - + static_cast( + x * arg.conv_dilations_[2]); + + if(w_tmp % arg.conv_strides_[2] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast( + arg.conv_strides_[2]); + if(wo >= 0 && + ck::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + float v_out = 0; + float v_wei = 0; + + arg.out_element_op_( + v_out, + ck::type_convert(arg.output_( + g, n, k, do_, ho, wo))); + + arg.wei_element_op_( + v_wei, + ck::type_convert( + arg.weight_(g, k, c, z, y, x))); + + v_acc += v_out * v_wei; + } + } + } + } + } + } + } + } + } + } + + float v_in; + + arg.in_element_op_(v_in, v_acc); + + arg.input_(g, n, c, di, hi, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_ncdhw, + arg.input_.GetLengths()[0], + arg.input_.GetLengths()[1], + arg.input_.GetLengths()[2], + arg.input_.GetLengths()[3], + arg.input_.GetLengths()[4], + arg.input_.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{input, + weight, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvBwdData" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7d62158f00c0aa28fdfe90d4a9df537367b22717 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// input descriptor in [G, N, C, Do, Ho, Wo] order +// weight descriptor in [G, K, C, Z, Y, X] order +// output descriptor in [G, N, K, Di, Hi, Wi] order +// phyiscal layout is irrelavent +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct ReferenceConvBwdWeight : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{in_n_c_hi_wi}, + weight_{wei_k_c_y_x}, + output_{out_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& input_; + Tensor& weight_; + const Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvBwdWeight::Argument; + + float Run(const Argument& arg) + { + if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 && + arg.weight_.GetNumOfDimension() == NDimSpatial + 3 && + arg.output_.GetNumOfDimension() == NDimSpatial + 3)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto f_kcx = [&](auto g, auto k, auto c, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n) + { + for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo) + { + auto wi = static_cast(wo * arg.conv_strides_[0]) + + static_cast(x * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.GetLengths()[3]) + { + float v_out; + float v_in; + + arg.out_element_op_( + v_out, ck::type_convert(arg.output_(g, n, k, wo))); + + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(g, n, c, wi))); + + v_acc += v_out * v_in; + } + } + } + + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(g, k, c, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcx, + arg.weight_.GetLengths()[0], + arg.weight_.GetLengths()[1], + arg.weight_.GetLengths()[2], + arg.weight_.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 2) + { + auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) { + std::size_t N = arg.output_.GetLengths()[1]; + + std::size_t Ho = arg.output_.GetLengths()[3]; + std::size_t Wo = arg.output_.GetLengths()[4]; + + float v_acc = 0; + + for(std::size_t n = 0; n < N; ++n) + { + for(std::size_t ho = 0; ho < Ho; ++ho) + { + auto hi = static_cast(ho * arg.conv_strides_[0]) + + static_cast(y * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + + for(std::size_t wo = 0; wo < Wo; ++wo) + { + auto wi = + static_cast(wo * arg.conv_strides_[1]) + + static_cast(x * arg.conv_dilations_[1]) - + static_cast(arg.in_left_pads_[1]); + + if(hi >= 0 && + ck::type_convert(hi) < arg.input_.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < arg.input_.GetLengths()[4]) + { + float v_out; + float v_in; + + arg.out_element_op_( + v_out, + ck::type_convert(arg.output_(g, n, k, ho, wo))); + + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(g, n, c, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(g, k, c, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcyx, + arg.weight_.GetLengths()[0], + arg.weight_.GetLengths()[1], + arg.weight_.GetLengths()[2], + arg.weight_.GetLengths()[3], + arg.weight_.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 3) + { + auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n) + { + for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_) + { + auto di = static_cast(do_ * arg.conv_strides_[0]) + + static_cast(z * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho) + { + auto hi = + static_cast(ho * arg.conv_strides_[1]) + + static_cast(y * arg.conv_dilations_[1]) - + static_cast(arg.in_left_pads_[1]); + for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo) + { + auto wi = + static_cast(wo * arg.conv_strides_[2]) + + static_cast(x * arg.conv_dilations_[2]) - + static_cast(arg.in_left_pads_[2]); + + if(di >= 0 && + ck::type_convert(di) < + arg.input_.GetLengths()[3] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.GetLengths()[4] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.GetLengths()[5]) + { + float v_out; + float v_in; + + arg.out_element_op_(v_out, + ck::type_convert( + arg.output_(g, n, k, do_, ho, wo))); + + arg.in_element_op_(v_in, + ck::type_convert( + arg.input_(g, n, c, di, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + } + + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(g, k, c, z, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kczyx, + arg.weight_.GetLengths()[0], + arg.weight_.GetLengths()[1], + arg.weight_.GetLengths()[2], + arg.weight_.GetLengths()[3], + arg.weight_.GetLengths()[4], + arg.weight_.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvBwdWeight" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b8d47d218b95e3ff2b38c2fc965b6bdc6a2afc68 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// +// @brief Reference implementation for forward convolution. +// +// @paragraph +// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order +// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout +// as long as dimensions in tensor descriptor is in GNCHW order +// +// @tparam InDataType Input tensor data type. +// @tparam WeiDataType Weights tensor data type. +// @tparam OutDataType Output tensor data type. +// @tparam InElementwiseOperation Functor for input tensor elementwise +// operation. +// @tparam WeiElementwiseOperation Functor for weights tensor elementwise +// operation. +// @tparam NDimSpatial Number of spatial dimensions. +// +// input descriptor in [G, N, C, Do, Ho, Wo] order +// weight descriptor in [G, K, C, Z, Y, X] order +// output descriptor in [G, N, K, Di, Hi, Wi] order +// phyiscal layout is irrelavent +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct ReferenceConvFwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : input_{input}, + weight_{weight}, + output_{output}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& input_; + const Tensor& weight_; + Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd::Argument; + + float Run(const Argument& arg) + { + if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 && + arg.weight_.GetNumOfDimension() == NDimSpatial + 3 && + arg.output_.GetNumOfDimension() == NDimSpatial + 3)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto func = [&](auto g, auto n, auto k, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c) + { + for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x) + { + auto wi = static_cast(wo * arg.conv_strides_[0]) + + static_cast(x * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(g, n, c, wi))); + + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(g, k, c, x))); + + v_acc += v_in * v_wei; + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + + arg.output_(g, n, k, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(func, + arg.output_.GetLengths()[0], + arg.output_.GetLengths()[1], + arg.output_.GetLengths()[2], + arg.output_.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 2) + { + auto func = [&](auto g, auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c) + { + for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y) + { + auto hi = static_cast(ho * arg.conv_strides_[0]) + + static_cast(y * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + + for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x) + { + auto wi = + static_cast(wo * arg.conv_strides_[1]) + + static_cast(x * arg.conv_dilations_[1]) - + static_cast(arg.in_left_pads_[1]); + + if(hi >= 0 && + ck::type_convert(hi) < arg.input_.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < arg.input_.GetLengths()[4]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(g, n, c, hi, wi))); + + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(g, k, c, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + + arg.output_(g, n, k, ho, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(func, + arg.output_.GetLengths()[0], + arg.output_.GetLengths()[1], + arg.output_.GetLengths()[2], + arg.output_.GetLengths()[3], + arg.output_.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 3) + { + auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c) + { + for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z) + { + auto di = static_cast(d_o * arg.conv_strides_[0]) + + static_cast(z * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y) + { + auto hi = + static_cast(ho * arg.conv_strides_[1]) + + static_cast(y * arg.conv_dilations_[1]) - + static_cast(arg.in_left_pads_[1]); + for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x) + { + auto wi = + static_cast(wo * arg.conv_strides_[2]) + + static_cast(x * arg.conv_dilations_[2]) - + static_cast(arg.in_left_pads_[2]); + if(di >= 0 && + ck::type_convert(di) < + arg.input_.GetLengths()[3] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.GetLengths()[4] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.GetLengths()[5]) + { + float v_in; + float v_wei; + + arg.in_element_op_(v_in, + ck::type_convert( + arg.input_(g, n, c, di, hi, wi))); + + arg.wei_element_op_( + v_wei, + ck::type_convert(arg.weight_(g, k, c, z, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + + arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(func, + arg.output_.GetLengths()[0], + arg.output_.GetLengths()[1], + arg.output_.GetLengths()[2], + arg.output_.GetLengths()[3], + arg.output_.GetLengths()[4], + arg.output_.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override + { + return NDimSpatial >= 1 && NDimSpatial <= 3; + } + + static auto MakeArgument(const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{input, + weight, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be22003fd90db20eaea80045bd34f908ed8ebb61 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template +struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc, static_cast(arg.bias_k_(k))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f949f27fde973cc44eec3162c7625980a9818116 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template +struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + resi_n_k_ho_wo_{resi_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + const Tensor& resi_n_k_ho_wo_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation_Add::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, + v_acc, + static_cast(arg.bias_k_(k)), + static_cast(arg.resi_n_k_ho_wo_(n, k, ho, wo))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation_Add" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6728bb1f471b4aad3f7ed0a2878fc7a81e1b569d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_m_n_(m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c77d22f4cd1c34d3fa81a973f115fc37b7d03cf0 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBias2D : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& c0_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c0_m_n_{c0_m_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const Tensor& c0_m_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBias2D::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType a = 0; + AccDataType b = 0; + AccDataType acc = 0; + + for(int k = 0; k < K; ++k) + { + arg.a_element_op_(a, ck::type_convert(arg.a_m_k_(m, k))); + arg.b_element_op_(b, ck::type_convert(arg.b_k_n_(k, n))); + acc += a * b; + } + + CDataType cast_acc = static_cast(acc); + arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n)); + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& c0_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBias2D" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7dfc3c1ed4b868756277f43d341cca5a1796662e --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivation::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, v_acc, static_cast(arg.c0_n_(n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivation" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99102a40d4e9a4c6360b3e6ecf9c8a345f5b38a9 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivationAdd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + c1_m_n_{c1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + const Tensor& c1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivationAdd::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, + v_acc, + static_cast(arg.c0_n_(n)), + static_cast(arg.c1_m_n_(m, n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivationAdd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..28132aa1ebd78178774caba73b91ba3c947b66be --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +struct ReferenceGemmLayernorm : public device::BaseOperator +{ + using ReferenceGemmInstance = ReferenceGemm; + + template + static void RunLayernorm(Tensor& result, + const Tensor& acc, // MxN + const Tensor& gamma, // 1xN + const Tensor& beta, // 1xN + const InDataType epsilon = 1e-5) + { + assert(acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] && + acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]); + + size_t M = acc.mDesc.GetLengths()[0]; + size_t N = acc.mDesc.GetLengths()[1]; + + Tensor avg_acc_sq({M}); + Tensor avg_acc({M}); + Tensor acc_layernorm(acc); + + // reduce N dim + for(size_t i = 0; i < M; i++) + { + ComputeDataType sum_acc_sq = 0; + ComputeDataType sum_acc = 0; + for(size_t j = 0; j < N; j++) + { + sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j); + sum_acc += acc_layernorm(i, j); + } + avg_acc_sq(i) = sum_acc_sq / N; + avg_acc(i) = sum_acc / N; + } + + // normalize + acc_layernorm.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) = + (self(idx[0], idx[1]) - avg_acc(idx[0])) / + sqrt(avg_acc_sq(idx[0]) - avg_acc(idx[0]) * avg_acc(idx[0]) + epsilon); + }); + + // affine + acc_layernorm.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) = self(idx[0], idx[1]) * gamma(idx[1]) + beta(idx[1]); + }); + + // cast + result = acc_layernorm.template CopyAsType(); + } + + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n_bias, // 1xN + const Tensor& c0_m_n_add, // MxN + const Tensor& c0_n_gamma, // 1xN + const Tensor& c0_n_beta, // 1xN + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + const CDataType epsilon = 1e-5) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_bias_{c0_n_bias}, + c0_m_n_add_{c0_m_n_add}, + c0_n_gamma_{c0_n_gamma}, + c0_n_beta_{c0_n_beta}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + c_element_op_{c_element_op}, + epsilon_{epsilon} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_bias_; + const Tensor& c0_m_n_add_; + const Tensor& c0_n_gamma_; + const Tensor& c0_n_beta_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + CElementwiseOperation c_element_op_; + + const CDataType epsilon_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + // using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg) + { + Tensor acc_m_n(arg.c_m_n_.mDesc); + acc_m_n.GenerateTensorValue(GeneratorTensor_1{0}); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_, + arg.b_k_n_, + acc_m_n, + arg.a_element_op_, + arg.b_element_op_, + element_wise::PassThrough{}); + + // gemm + ref_invoker.Run(ref_argument); + + // activation(acc + bias) + acc_m_n.ForEach([&](auto& self, auto idx) { + AccDataType out; + arg.acc_element_op_(out, acc_m_n(idx[0], idx[1]) + arg.c0_n_bias_(idx[1])); + self(idx[0], idx[1]) = out; + }); + + // add from other layers + acc_m_n.ForEach([&](auto& self, auto idx) { + self(idx[0], idx[1]) += arg.c0_m_n_add_(idx[0], idx[1]); + }); + + // layernorm + RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_gamma_, arg.c0_n_beta_); + + // elementwise op + arg.c_m_n_.ForEach([&](auto& self, auto idx) { + arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1])); + }); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n_bias, // 1xN + const Tensor& c0_m_n_add, // 1xN + const Tensor& c0_n_gamma, // 1xN + const Tensor& c0_n_beta, // 1xN + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + const CDataType epsilon = 1e-5) + { + return Argument{a_m_k, + b_k_n, + c_m_n, + c0_n_bias, + c0_m_n_add, + c0_n_gamma, + c0_n_beta, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + epsilon}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmLayernorm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fedd4dce62cdcd31c2c582ab034291e9e098bc3c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGroupnorm : public device::BaseOperator +{ + // x = [N, H, W, G, C] + // y = [N, H, W, G, C] + // reduce dim [H, W, C], mean, var = [N, G] + // gamma, beta = [G, C] + // beta: [G, C] + struct Argument : public device::BaseArgument + { + Argument(const Tensor& x, + const Tensor& gamma, + const Tensor& beta, + Tensor& y, + AccElementwiseOperation acc_elementwise_op, + const std::vector lengths, + AccDataType epsilon) + : x_(x), + gamma_(gamma), + beta_(beta), + y_(y), + acc_elementwise_op_(acc_elementwise_op), + lengths_(lengths), + epsilon_(epsilon) + { + } + + const Tensor x_; + const Tensor gamma_; + const Tensor beta_; + Tensor& y_; + AccElementwiseOperation acc_elementwise_op_; + std::vector lengths_; + AccDataType epsilon_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + int N = arg.lengths_[0]; + int H = arg.lengths_[1]; + int W = arg.lengths_[2]; + int G = arg.lengths_[3]; + int C = arg.lengths_[4]; + + Tensor mean({N, G}); + Tensor var({N, G}); + + // Compute mean & var in [H, W, C] by Welford Algorithm + // TODO - parallel for each HWC + // TODO - address calculation + for(int n = 0; n < N; ++n) + { + for(int g = 0; g < G; ++g) + { + AccDataType mean_val = type_convert(0.0f); + AccDataType var_val = type_convert(0.0f); + int32_t curr_count = 0; + + for(int h = 0; h < H; ++h) + { + for(int w = 0; w < W; ++w) + { + for(int c = 0; c < C; ++c) + { + curr_count++; + AccDataType x = type_convert(arg.x_(n, h, w, g, c)); + AccDataType delta = x - mean_val; + mean_val += delta / curr_count; + AccDataType delta2 = x - mean_val; + var_val += delta * delta2; + } + } + } + + mean(n, g) = mean_val; + var(n, g) = var_val / curr_count; + } + } + + // Normalization + for(int n = 0; n < N; ++n) + { + for(int h = 0; h < H; ++h) + { + for(int w = 0; w < W; ++w) + { + for(int g = 0; g < G; ++g) + { + for(int c = 0; c < C; ++c) + { + AccDataType x = type_convert(arg.x_(n, h, w, g, c)); + AccDataType gamma = type_convert(arg.gamma_(g, c)); + AccDataType beta = type_convert(arg.beta_(g, c)); + AccDataType mean_val = type_convert(mean(n, g)); + AccDataType var_val = type_convert(var(n, g)); + AccDataType y = gamma * (x - mean_val) / + ck::math::sqrt(arg.epsilon_ + var_val) + + beta; + arg.acc_elementwise_op_(y, y); + arg.y_(n, h, w, g, c) = type_convert(y); + } + } + } + } + } + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + if(p_arg_->lengths_.size() != 5) + return false; + + return true; + } + + static auto MakeArgument(const Tensor& x, + const Tensor& gamma, + const Tensor& beta, + Tensor& y, + AccElementwiseOperation acc_elementwise_op, + const std::vector lengths, + AccDataType epsilon) + { + return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceLayernorm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..680d94f7d1dcc3c474ba665887a132693e323c0f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceLayernorm : public device::BaseOperator +{ + // TODO - support generic layernorm + static_assert((Rank == 2 && NumReduceDim == 1), "Only support 2D version so far"); + + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& x_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + Tensor& y_m_n, + AccElementwiseOperation acc_elementwise_op, + const std::vector lengths, + const std::vector reduceDims, + AccDataType epsilon) + : x_m_n_(x_m_n), + gamma_n_(gamma_n), + beta_n_(beta_n), + y_m_n_(y_m_n), + acc_elementwise_op_(acc_elementwise_op), + lengths_(lengths), + reduceDims_(reduceDims), + epsilon_(epsilon) + { + } + + const Tensor x_m_n_; + const Tensor gamma_n_; + const Tensor beta_n_; + Tensor& y_m_n_; + AccElementwiseOperation acc_elementwise_op_; + std::vector lengths_; + std::vector reduceDims_; + AccDataType epsilon_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + int M = arg.lengths_[0]; + int N = arg.lengths_[1]; + + Tensor mean({M}); + Tensor var({M}); + + for(int m = 0; m < M; ++m) + { + mean(m) = 0; + var(m) = 0; + + for(int n = 0; n < N; ++n) + { + auto x_val = ck::type_convert(arg.x_m_n_(m, n)); + mean(m) += x_val; + var(m) += x_val * x_val; + } + + mean(m) = mean(m) / N; + var(m) = (var(m) / N) - (mean(m) * mean(m)); + } + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + auto x_val = ck::type_convert(arg.x_m_n_(m, n)); + auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_); + y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); + arg.acc_elementwise_op_(y_val, y_val); + arg.y_m_n_(m, n) = ck::type_convert(y_val); + } + } + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + // TODO - support generic layernorm + if(p_arg_->lengths_.size() != 2) + return false; + + if(p_arg_->reduceDims_.size() != 1) + return false; + + if(p_arg_->reduceDims_[0] != 1) + return false; + + return true; + } + + static auto MakeArgument(const Tensor& x_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + Tensor& y_m_n, + AccElementwiseOperation acc_elementwise_op, + const std::vector lengths, + const std::vector reduceDims, + AccDataType epsilon) + { + return Argument{ + x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceLayernorm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4839eb8ade3b13c337e65f0dbc4aa6a845ee4e9a --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceSoftmax : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const std::vector sm_reduce_dims) + : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims) + { + // std::cout << "debug: scalar dims: "; + for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++) + { + if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) == + sm_reduce_dims.end()) + { + sm_scalar_dims_.push_back(i); + // std::cout << i << ", "; + } + } + // std::cout << std::endl; + } + + const Tensor& in_; + Tensor& out_; + AccDataType alpha_; + AccDataType beta_; + std::vector sm_reduce_dims_; + std::vector sm_scalar_dims_; // dim after internal max/sum reduction + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + std::vector scalar_lengths; + for(index_t dim : arg.sm_scalar_dims_) + { + scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]); + } + // max and sum reduction with final reduced values of dim=0 is a scalar so give it + // appropriate lengths of {1} + if(arg.sm_scalar_dims_.size() == 0) + { + scalar_lengths.push_back(1); + } + + Tensor reduce_max(scalar_lengths); + reduce_max.GenerateTensorValue( + GeneratorTensor_1{std::numeric_limits::lowest()}); + Tensor reduce_sum(scalar_lengths); + reduce_sum.GenerateTensorValue(GeneratorTensor_1{0}); + + // when final reduced values is of dim=0, the index will be transformed into empty + // std::vector which is actually a valid input for Tensor::operator(std::vector) and + // internally accesses 0'th element + auto to_sm_scalar_idx = [&](auto idx) { + std::vector sm_scalar_idx; + for(index_t dim : arg.sm_scalar_dims_) + { + sm_scalar_idx.push_back(idx[dim]); + } + return sm_scalar_idx; + }; + + arg.in_.ForEach([&](auto& self, auto idx) { + reduce_max(to_sm_scalar_idx(idx)) = std::max( + reduce_max(to_sm_scalar_idx(idx)), ck::type_convert(self(idx))); + }); + + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; + + Tensor in_stable(arg.in_.mDesc); + in_stable.ForEach([&](auto& self, auto idx) { + // numerator = exp(x - max(x)) + self(idx) = std::exp(ck::type_convert(arg.in_(idx)) - + reduce_max(to_sm_scalar_idx(idx))); + }); + + // LogRangeAsType(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl; + + in_stable.ForEach([&](auto& self, auto idx) { + // denominator = sum(exp(x - max(x))) + reduce_sum(to_sm_scalar_idx(idx)) += self(idx); + }); + + // LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") << + // std::endl; + + arg.out_.ForEach([&](auto& self, auto idx) { + AccDataType temp_result = + arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) + + arg.beta_ * self(idx); + self(idx) = ck::type_convert(temp_result); + }); + + // LogRangeAsType(std::cout << "out: ", arg.out_.mData, ",") << std::endl; + // reduction along reduce dims + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") + // << std::endl; + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const std::vector sm_reduce_dims) + { + return Argument{in, out, alpha, beta, sm_reduce_dims}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceSoftmax" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6a9b0fb5ee4a3135750dad0d0818a0f6e439f3a --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceSparseEmbedding3ForwardLayernorm : public device::BaseOperator +{ + struct Argument : public device::BaseArgument + { + Argument(Tensor& output, + const Tensor& emb_a, + const Tensor& emb_b, + const Tensor& emb_c, + const Tensor& index_a, + const Tensor& index_b, + const Tensor& index_c, + const Tensor& gamma, + const Tensor& beta, + ck::index_t NumRows, + ck::index_t EmbeddingDim, + ck::index_t IndexLength, + AccDataType epsilon) + : output_(output), + emb_a_(emb_a), + emb_b_(emb_b), + emb_c_(emb_c), + index_a_(index_a), + index_b_(index_b), + index_c_(index_c), + gamma_(gamma), + beta_(beta), + NumRows_(NumRows), + EmbeddingDim_(EmbeddingDim), + IndexLength_(IndexLength), + epsilon_(epsilon) + { + } + Tensor& output_; + const Tensor emb_a_; + const Tensor emb_b_; + const Tensor emb_c_; + const Tensor index_a_; + const Tensor index_b_; + const Tensor index_c_; + const Tensor gamma_; + const Tensor beta_; + ck::index_t NumRows_; + ck::index_t EmbeddingDim_; + ck::index_t IndexLength_; + AccDataType epsilon_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + ck::index_t D = arg.EmbeddingDim_; + ck::index_t L = arg.IndexLength_; + ck::index_t E = arg.NumRows_; + + Tensor accumulator({L, D}); + + Tensor mean({L}); + Tensor var({L}); + + accumulator.SetZero(); + + auto f_emb_per_row = [&](auto idx) { + IndexType idx_a = arg.index_a_(idx); + IndexType idx_b = arg.index_b_(idx); + IndexType idx_c = arg.index_c_(idx); + + if(!((idx_a < E) && (idx_b < E) && (idx_c < E))) + { + throw(std::runtime_error("wrong! out of range")); + } + + for(auto d = 0; d < D; d++) + { + auto v_a = ck::type_convert(arg.emb_a_(idx_a, d)); + auto v_b = ck::type_convert(arg.emb_b_(idx_b, d)); + auto v_c = ck::type_convert(arg.emb_c_(idx_c, d)); + + accumulator(idx, d) += v_a + v_b + v_c; + } + }; + make_ParallelTensorFunctor(f_emb_per_row, L)(std::thread::hardware_concurrency()); + + // layernorm + for(auto idx = 0; idx < L; ++idx) + { + mean(idx) = 0; + var(idx) = 0; + + for(auto d = 0; d < D; ++d) + { + auto x_val = accumulator(idx, d); + mean(idx) += x_val; + var(idx) += x_val * x_val; + } + + mean(idx) = mean(idx) / D; + var(idx) = (var(idx) / D) - (mean(idx) * mean(idx)); + } + + for(auto idx = 0; idx < L; ++idx) + { + for(auto d = 0; d < D; ++d) + { + auto x_val = accumulator(idx, d); + auto y_val = (x_val - mean(idx)) / sqrt(var(idx) + arg.epsilon_); + y_val = (y_val * arg.gamma_(d)) + arg.beta_(d); + arg.output_(idx, d) = ck::type_convert(y_val); + } + } + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(Tensor& output, + const Tensor& emb_a, + const Tensor& emb_b, + const Tensor& emb_c, + const Tensor& index_a, + const Tensor& index_b, + const Tensor& index_c, + const Tensor& gamma, + const Tensor& beta, + ck::index_t NumRows, + ck::index_t EmbeddingDim, + ck::index_t IndexLength, + AccDataType epsilon) + { + return Argument(output, + emb_a, + emb_b, + emb_c, + index_a, + index_b, + index_c, + gamma, + beta, + NumRows, + EmbeddingDim, + IndexLength, + epsilon); + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceSparseEmbedding3ForwardLayernorm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..df4fca6562755012200ff888677bcac1c90129a2 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef NAIVE_CONV_FWD_HPP +#define NAIVE_CONV_FWD_HPP + +namespace ck { +namespace ref { + +/* + * \brief naive implementation of 3D convolution. Layout is (NDHWC, KZYXC, NDHWK). + * + * \param N number of batches + * \param K number of filters + * \param C number of channels of weight + * \param (Di, Hi, Wi) depth, height and width dimension of data + * \param (Z, Y, X) depth, height and width dimensions of weights + * \param (Do, Ho, Wo) depth, height and width dimension of output + * \param (stride_z, stride_y, stride_x) strides + * \param (dilation_z, dilation_y, dilation_x) dilations + * \param (pad_z, pad_y, pad_x) pads + */ +template +__global__ void naive_conv_fwd_ndhwc_kzyxc_ndhwk(const TIn* __restrict__ p_in, + const TWei* __restrict__ p_wei, + TOut* __restrict__ p_out, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x) +{ + const index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const index_t num_threads = blockDim.x * gridDim.x; + const long_index_t output_length = N * Do * Ho * Wo * K; + + const index_t out_strides[] = {Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K}; + const index_t in_strides[] = {Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C}; + const index_t wei_strides[] = {Z * Y * X * C, Y * X * C, X * C, C}; + + constexpr auto in_op = InElementwiseOperation{}; + constexpr auto wei_op = WeiElementwiseOperation{}; + constexpr auto out_op = OutElementwiseOperation{}; + + TIn in_val; + TWei wei_val; + TOut out_val; + + for(long_index_t ii = tid; ii < output_length; ii += num_threads) + { + const index_t n = ii / out_strides[0]; + index_t k = ii - n * out_strides[0]; + const index_t dO = k / out_strides[1]; + k -= dO * out_strides[1]; + const index_t ho = k / out_strides[2]; + k -= ho * out_strides[2]; + const index_t wo = k / out_strides[3]; + k -= wo * out_strides[3]; + + TAcc acc = static_cast(0); + + const TIn* in_n = p_in + static_cast(n) * in_strides[0]; + const TWei* wei_k = p_wei + static_cast(k) * wei_strides[0]; + + for(index_t z = 0; z < Z; ++z) + { + index_t di = stride_z * dO - pad_z + dilation_z * z; + const TIn* in_n_di = in_n + di * in_strides[1]; + const TWei* wei_k_z = wei_k + z * wei_strides[1]; + + for(index_t y = 0; y < Y; ++y) + { + index_t hi = stride_y * ho - pad_y + dilation_y * y; + const TIn* in_n_di_hi = in_n_di + hi * in_strides[2]; + const TWei* wei_k_z_y = wei_k_z + y * wei_strides[2]; + + for(index_t x = 0; x < X; ++x) + { + index_t wi = stride_x * wo - pad_x + dilation_x * x; + const TIn* in_n_di_hi_wi = in_n_di_hi + wi * in_strides[3]; + const TWei* wei_k_z_y_x = wei_k_z_y + x * wei_strides[3]; + + if(di >= 0 && di < Di && hi >= 0 && hi < Hi && wi >= 0 && wi < Wi) + { + for(index_t c = 0; c < C; ++c) + { + in_op(in_val, in_n_di_hi_wi[c]); + wei_op(wei_val, wei_k_z_y_x[c]); + acc += in_val * wei_val; + } + } + } + } + } + + out_op(out_val, static_cast(acc)); + p_out[ii] = out_val; + } +} +} // namespace ref +} // namespace ck + +#endif diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..20df1b3616a016029f4083bd821b4e286b7b727c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/functional2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +void add_device_operation_instances(std::vector>& op_instances, + const NewOpInstances& new_op_instances) +{ + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + const auto new_op_instance = std::get(new_op_instances); + + using NewOpInstance = remove_cvref_t; + + static_assert(std::is_base_of_v, + "wrong! NewOpInstance should be derived from BaseOp"); + + op_instances.push_back(std::make_unique(new_op_instance)); + }); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91980a9a66c6f57cdb8557f8ee6306abf31bf77d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// aliasing, for commonly used data type +using F64 = double; +using F32 = float; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using I32 = int32_t; + +using Empty_Tuple = ck::Tuple<>; + +using F16_Tuple = ck::Tuple; +using F16_F16_Tuple = ck::Tuple; + +using F32_Tuple = ck::Tuple; +using I32_Tuple = ck::Tuple; +using I32_F32_Tuple = ck::Tuple; + +// GEMM layout +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Row_Tuple = ck::Tuple; +using Row_Row_Tuple = ck::Tuple; + +// Conv layout +// +using NWC = ck::tensor_layout::convolution::NWC; +using NHWC = ck::tensor_layout::convolution::NHWC; +using NDHWC = ck::tensor_layout::convolution::NDHWC; + +using KXC = ck::tensor_layout::convolution::KXC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; + +using NWK = ck::tensor_layout::convolution::NWK; +using NHWK = ck::tensor_layout::convolution::NHWK; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +// +using GNWC = ck::tensor_layout::convolution::GNWC; +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + +using GKXC = ck::tensor_layout::convolution::GKXC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + +using GNWK = ck::tensor_layout::convolution::GNWK; +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +// +using NWGC = ck::tensor_layout::convolution::NWGC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + +using KXGC = ck::tensor_layout::convolution::KXGC; +using KYXGC = ck::tensor_layout::convolution::KYXGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; + +using NWGK = ck::tensor_layout::convolution::NWGK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +// +using GK = ck::tensor_layout::convolution::G_K; +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; + +// pointwise functor +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; +using Scale = ck::tensor_operation::element_wise::Scale; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +template +using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +template +using Add_Activation_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +template +using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +template +using Add_Activation_Mul2_Clamp = + ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +template +struct DeviceOperationInstanceFactory; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0655fd92e44c5c6400c5a66e6ee9e21f6e50da71 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( + std::vector>>& + instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceBatchedGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..495c5f884fd633898867a3e0a5fc121326527c8e --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu; +using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector, + Row, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& + instances); + +void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector, + Col, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD> +{ + using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6dcfa30d3efd2d78562c1829ffb72b4eaabca9b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances); + +void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector>>& instances); +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchedGemmGemm> +{ + using DeviceOp = DeviceBatchedGemmGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8a0b1b1fa7de54954a5bc375241bb7a0686b2e49 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances); + +void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm> +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(MaskOutUpperTriangle) + { + add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + else + { + add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89df1a7a0d44394d190973ad787b176da8002851 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances); + +void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances); + +void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances); + +void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpec>> +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpec>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + else if(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + else if(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c84ffcff8cbb8d4820f88dbfa3e7de6a04961bb8 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// FP16 +void add_device_batchnorm_backward_rank_4_3_f16_instances( + std::vector>>&); + +// FP32 +void add_device_batchnorm_backward_rank_4_3_f32_instances( + std::vector>>&); + +// BF16 +void add_device_batchnorm_backward_rank_4_3_bf16_instances( + std::vector>>&); + +// FP64 +void add_device_batchnorm_backward_rank_4_3_f64_instances( + std::vector>>&); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchNormBwd> +{ + using DeviceOp = DeviceBatchNormBwd; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8e40d60c17b8dc8d743e29cc1472a1e9eed9bdc1 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// FP16 +void add_device_batchnorm_forward_rank_4_3_f16_instances( + std::vector< + std::unique_ptr>>&); + +// FP32 +void add_device_batchnorm_forward_rank_4_3_f32_instances( + std::vector< + std::unique_ptr>>&); + +// BF16 +void add_device_batchnorm_forward_rank_4_3_bf16_instances( + std::vector< + std::unique_ptr>>&); + +// FP64 +void add_device_batchnorm_forward_rank_4_3_f64_instances( + std::vector< + std::unique_ptr>>&); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchNormFwd> +{ + using DeviceOp = DeviceBatchNormFwd; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0cea7e390aab9402db40a439ffb43ccfb335bde --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector>>& instances); + +// Contraction + Bilinear +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>> +{ + using DeviceOp = DeviceContractionMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e921ecd47aa5b08d2f6f63ee2bad30c40608b05c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector>>& instances); + +// Contraction + Scale +template +struct DeviceOperationInstanceFactory, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>> +{ + using DeviceOp = DeviceContractionMultipleD, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec5d18fc214c34c1e972aa5fc50726cd41f8180c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward data +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>>& instances); + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( + std::vector>>& + instances); + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( + std::vector>>& + instances); + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( + std::vector>>& instances); + +// conv2d backward data +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>>& instances); + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& instances); + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& instances); + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances); + +// conv2d dl +void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& instances); + +void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& instances); + +void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances); +// conv3d backward data +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>>& instances); + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>>& instances); + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>>& instances); + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceConvBwdData; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 1 && is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); + add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); + add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62f28c9b11d287f355d5feeeb24bf7863ac96851 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv2d forward +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>>& instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& + instances); + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceConvFwd> +{ + using DeviceOp = DeviceConvFwd; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..141af558471852471833459edb1b9a07e1108069 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Normalize = ck::tensor_operation::element_wise::Normalize; +using DeviceNormalizeFromMeanMeanSquarePtr = ck::tensor_operation::device::DeviceElementwiseBasePtr< + Tuple, + Tuple, + Normalize, + 2>; + +void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( + std::vector& instances); + +template +auto get_device_normalize_from_mean_meansquare_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value && is_same::value && + is_same::value && is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs); + } + + return op_ptrs; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..682f54675982e9338823a814c46075242534de1d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +template +auto get_device_gemm_add_add_mean_squaremean_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + op_ptrs); + } + } + + return op_ptrs; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c87ae159bee2635c903490b9e09096be8c242c8c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// FP16 +void add_device_elementwise_normalization_rank_2_1_f16_instances( + std::vector, + F16, + F16, + F32, + F16, + element_wise::Add, + PassThrough, + 2, + 1>>>&); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceElementwiseNormalization; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_elementwise_normalization_rank_2_1_f16_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e230507e7e30a605bac4ed5d27cea454739ba24a --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + + instances); + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemm> +{ + using DeviceOp = DeviceGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..09d8e8b95bb28f393e8d9b467363b0ccb78a1124 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + std::vector>>&); + +// GEMM + Add + Add + FastGelu +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..554437f4903bcae96f68f2a53097a6eaf88367f3 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>&); + +// GEMM + Add + FastGelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ef70504f29b717ed70a494926d256f8b98a35e47 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances); + +// GEMM + Bilinear +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fbc5df98a4ec116976d491835e8a8411b609c992 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>&); + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>&); + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>&); + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>&); + +// GEMM + FastGelu +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGemmMultipleD; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8986a793444c8f73679d892f4d55adb02bc61834 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmSplitK> +{ + using DeviceOp = DeviceGemmSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..81b2b4fcf37cb1f0d05415eb7d56587f0a29f380 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv2d backward data +void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD< + NumDimSpatial, + OutLayout, + WeiLayout, + Empty_Tuple, + InLayout, + OutDataType, + WeiDataType, + Empty_Tuple, + InDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>> +{ + using DeviceOp = + DeviceGroupedConvBwdDataMultipleD; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ef6920e52a2096a3f61e40435c6fa3ce64a17323 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); + +// conv2d backward weight +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +// conv3d backward weight +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGroupedConvBwdWeight; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 1 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + op_ptrs); + } + } + else if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + op_ptrs); + } + } + else if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee38b738274c5492f27ed0e7ad5c285eacf1f54b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -0,0 +1,396 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv1d forward, GNWC/GKXC/GNWK +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( + std::vector>>& instances); + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( + std::vector>>& instances); +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 1 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + // no instance + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + // no instance + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + // no instance + } + } + else if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c64598daddb6874bd9c01eba6045b42c91792d14 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGroupedGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55c67b7623b85c03381833431338f6be083fbd7b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// FP16 +void add_device_normalization_rank_2_1_f16_instances( + std::vector>>&); + +void add_device_normalization_rank_4_3_f16_instances( + std::vector>>&); + +void add_device_normalization_rank_5_3_f16_instances( + std::vector>>&); + +// FP32 +void add_device_normalization_rank_2_1_f32_instances( + std::vector>>&); + +void add_device_normalization_rank_4_3_f32_instances( + std::vector>>&); + +void add_device_normalization_rank_5_3_f32_instances( + std::vector>>&); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceNormalization; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_normalization_rank_2_1_f16_instances(op_ptrs); + } + else if constexpr(Rank == 4 && NumReduceDim == 3) + { + add_device_normalization_rank_4_3_f16_instances(op_ptrs); + } + else if constexpr(Rank == 5 && NumReduceDim == 3) + { + add_device_normalization_rank_5_3_f16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_normalization_rank_2_1_f32_instances(op_ptrs); + } + else if constexpr(Rank == 4 && NumReduceDim == 3) + { + add_device_normalization_rank_4_3_f32_instances(op_ptrs); + } + else if constexpr(Rank == 5 && NumReduceDim == 3) + { + add_device_normalization_rank_5_3_f32_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eda81a233c01c7c2bff96a873a3db823f49f572c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_bias_perchannel_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_bias_perchannel_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..113840263893688f2adfc17060a77428628ad6bd --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_bias_perlayer_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_bias_perlayer_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1a67ce56880d5ce21f0116ce25a08b53c474fc6d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_relu_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_perchannel_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_relu_perchannel_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..410be4a57bfce65c9c6877ac8508d3cd2921f69f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_relu_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_perlayer_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_relu_perlayer_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..550a7b034509af43412c0faf5ac2fb52fb8b2eb3 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp" diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90cfe837df62e145105f964967e5180ce6b2e768 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using reduce_configuration_1_instances_blockwise = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_blockwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_blockwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +void add_device_reduce_instance_blockwise( + std::vector>& + device_op_instances) +{ + static_for<0, std::tuple_size::value, 1>{}( + [&](auto i) { + using cfg1 = remove_cvref_t(reduce_configuration_1_instances_blockwise{}))>; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_blockwise{}))>; + + using ReduceOpInstance = + DeviceReduceMultiBlock; + + device_op_instances.push_back( + std::make_unique(ReduceOpInstance{})); + }); + }); +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..521d93e6001dbf2662ec52c2955416d94f6b10e5 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fe3fd6c0a7b77620654577dbfa4774ffabe01cb0 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..52a2b69cdd2175713f5e08a1cbbfe6f2da59360f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee4fee41ea4c38730f9c494ccc3774672856e8cc --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3abdb7f9588bd11686146980bfc3c1bd333a5aed --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b0dbcf31dd865529e702333188af0bbfa68d9bfe --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7bbf3df0a375af62c1ecdbbc47e13fb0e6779e78 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..559f322261e7a7ffcb8473058a02623b0dedf197 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..28c96107893f5be14474f883a455c601000fd418 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5080d2863644f37b71238ca78198c3150575aa71 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0d24d15371d02e7f6492f93ee3da7f7ada6073d6 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c806e807c8ed400c59460c72666c4e9277b0ded9 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b7c046e751f8816b3ccf1bedf806391b71afc491 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..771bec1c95b9ba4e47c5741dafe42104a5abcaa3 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c1fe8addba117e9b92de6fad69220263e8176328 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6bc0662fea97b0221e20201a178423a6524b9935 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f8005132decf43667334cf1c39f38c8b306af35 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c771ac4fab9beeee4a1b3516d1fcfd0fd62bd2b2 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b9ddbb9aea2e14043a948a3b322b8a7cf7330ecc --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..390a719ceb1951369b8cd628ced732aed640b87b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2a9ddbc61b3c21713bc50d370ff040efea46435d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5746884442880571f8e00f49e5d4144adb951c60 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ad0f2357e05728179c8c75d22808a125a3cf9340 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c7d95276380553f4c3da54813329097a901b13c0 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec56229937a850d2f2708780decc2e024eb8e9f0 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..48f66da659b1e39bf64344a1a9339b034207a3ff --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fabfa5b4c6f7a335dd007ad95137bae5de876b8a --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e08faec2000ee51745131297fcb4d75ea29c46e4 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a1e692aae38ade1a30b96873340aff721e11984b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9654e8cceba39d9115d2cf2222dc3d72b75d49b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..78244213097ba7c11168b2f0a7c949133daaac3f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..df323d40b39649599807bf1e303135d0cbe551cf --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +extern template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c08e5ef2f0365e36125d141f95474573dce5bf6 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct ReductionConfiguration_1 +{ + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid Configuration!"); + + static constexpr int BlockSize_ = BlockSize; + static constexpr int MThreadClusterSize_ = MThreadClusterSize; + static constexpr int KThreadClusterSize_ = KThreadClusterSize; +}; + +template +struct ReductionConfiguration_2 +{ + static constexpr int InSrcVectorDim_ = InSrcVectorDim; + static constexpr int InSrcVectorSize_ = InSrcVectorSize; + static constexpr int OutDstVectorSize_ = OutDstVectorSize; + static constexpr int MThreadSliceSize_ = MThreadSliceSize; + static constexpr int KThreadSliceSize_ = KThreadSliceSize; +}; + +using ReduceAdd = ck::reduce::Add; +using ReduceMin = ck::reduce::Min; +using ReduceMax = ck::reduce::Max; +using ReduceAMax = ck::reduce::AMax; + +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnarySqrt = ck::tensor_operation::element_wise::UnarySqrt; +using UnaryDivide = ck::tensor_operation::element_wise::UnaryDivide; +using UnaryAbs = ck::tensor_operation::element_wise::UnaryAbs; + +#define QUICK_REDUCE_TEST 1 + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..acf55d068392c191912c3f04c167ac185997c031 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using reduce_configuration_1_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +void add_device_reduce_instance_multiblock_atomic_add( + std::vector>& + device_op_instances) +{ + static_for<0, + std::tuple_size::value, + 1>{}([&](auto i) { + using cfg1 = remove_cvref_t(reduce_configuration_1_instances_multiblock_atomic_add{}))>; + + static_for<0, + std::tuple_size::value, + 1>{}([&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_multiblock_atomic_add{}))>; + + using ReduceOpInstance = DeviceReduceMultiBlock; + + device_op_instances.push_back(std::make_unique(ReduceOpInstance{})); + }); + }); +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f5102f49770af8f3e0cb057a74ef31c925469dd0 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec513113d9e9ed08806b3560a65e6f1caa656e43 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a3d53b8c67b8416113256466f0e410ae998614a --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bbf439896430e1b88ab6560b9279086958cb75b1 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55147a60e56422b703ae347e56eafc8a66f644f7 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4bff06c6afe30abbec780e147d975fcf7ce2d68f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..daffa1aa4d43fb5b751931c345938282421e17f4 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..52c4171123f87bd19eed10057f2ef5df07e2b444 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2f358b06e0e55fa6df1ad1e842e7548c3f81c081 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..84c99dcc5759f0d4101d465e84e1e525ffd8fb7e --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +extern template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dfcc8dd8548624f19e16f1521197018e6e9fb7d7 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef QUICK_REDUCE_TEST +using reduce_configuration_2_instances_threadwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 2, 2, 2, 1>, + ReductionConfiguration_2<0, 1, 1, 2, 1>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<1, 1, 1, 1, 3> + // clang-format on + >; +#else +using reduce_configuration_2_instances_threadwise = std::tuple< + // clang-format off + // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize + ReductionConfiguration_2<0, 4, 4, 8, 1>, + ReductionConfiguration_2<0, 4, 4, 4, 1>, + ReductionConfiguration_2<0, 2, 2, 2, 1>, + + ReductionConfiguration_2<1, 4, 1, 1, 8>, + ReductionConfiguration_2<1, 4, 1, 1, 4>, + ReductionConfiguration_2<1, 2, 1, 1, 2>, + + // special instances + ReductionConfiguration_2<0, 1, 1, 3, 1>, + ReductionConfiguration_2<0, 1, 1, 5, 1>, + ReductionConfiguration_2<0, 1, 1, 7, 1>, + ReductionConfiguration_2<0, 1, 1, 11, 1>, + + ReductionConfiguration_2<1, 1, 1, 1, 3>, + ReductionConfiguration_2<1, 1, 1, 1, 5>, + ReductionConfiguration_2<1, 1, 1, 1, 7>, + ReductionConfiguration_2<1, 1, 1, 1, 11> + // clang-format on + >; +#endif + +template +void add_device_reduce_instance_threadwise( + std::vector>& + device_op_instances) +{ + using cfg1 = ReductionConfiguration_1<256, 256, 1>; + + static_for<0, std::tuple_size::value, 1>{}( + [&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_threadwise{}))>; + + using ReduceOpInstance = DeviceReduceThreadWise; + + device_op_instances.push_back(std::make_unique(ReduceOpInstance{})); + }); +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4168508b28dfd3f255696a76e198e5cf6c5692b9 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..317006e3a5c7c91543eb0bf81040ce86d1a41dc9 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fc7718ddc044799dde80dff814b41d2beb37f6c7 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e6616386ca44280dc5d3c1389acff83e58d97a27 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a9441b8e8eac32e6d96dfb923f87bd68bb32b2d3 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6820ace8cf0358bd3d2ecf97c3005699fe9b7d54 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab3d4e6e2c4f3ce4565e6ceec25527587be8b499 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee08c9635b9ee109c19153fc4b2f959731e5380e --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1007ca27bb976c6ea515b521762a02d385022057 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1d562c49991c9b4f34db407a5847ed1b1144a7ff --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5aac638b1eb444289b926c37a171ac6775ac86f2 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7a3c7640973756e332b695b090aff45e037a8a9f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4685d7b5d551b211f96c4ad9c590b42e77cb8431 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1de338fb4889dd9c2954b030d8b31866afbd07a6 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e86c41a949775a3e4cef55c7f0e91484212a7a78 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2ca9008560bbef4fab228000b88f7f059868aecd --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..38380e71ec0f289da2f0c04fe0db7d028bf4453c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..04c5f3e6585ec2f336b9888405ee4b33657a5aa5 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fef5d408845cd7a76427ea5142212728345efe6d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2416f614c34eae56cedfbe41f3c2d1b8f365ef96 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fbd0285ae82a340863fd5c41bc0a1e48a3401dbe --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..103b85a011d578477268ad209725b80e92c52e2b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e01f590f0eae7ecf47683d3463fbe03960361fcb --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..14a7459bb8a7b631f92f45792e12880a65a09359 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7dfd806012012016038b5834e7577cdb77708cff --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7670a27c844bd2cb6a71bae3d33b5138a2fa6da0 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8bb85f377928a8c3d1394f9e27da16b9795dcd85 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a005ba8d426e7f9447e80c70946e95ba463396d1 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e8c07eb4f489afe9dfcc6cba3a70f1a6fd1c7b4 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a69f88f5a9c6d8930ada6c95ec0c0b05a3557b83 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp new file mode 100644 index 0000000000000000000000000000000000000000..734b31c1e977b7a4a6d80494208d5560be17dcc4 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp new file mode 100644 index 0000000000000000000000000000000000000000..237bd969668d5a24d58234809c0628b7e7239e2f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +extern template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..36eb092f0f0707b266ef9a1d7c3b5248947c1e87 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_instances( + std::vector>&); +void add_device_softmax_f16_f16_rank4_instances( + std::vector>&); + +void add_device_softmax_f32_f32_rank3_instances( + std::vector>&); +void add_device_softmax_f32_f32_rank4_instances( + std::vector>&); + +void add_device_softmax_i8_i8_rank3_instances( + std::vector>&); +void add_device_softmax_i8_i8_rank4_instances( + std::vector>&); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device:: + DeviceSoftmax> +{ + using DeviceOp = + DeviceSoftmax; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + if constexpr(Rank == 3) + add_device_softmax_f16_f16_rank3_instances(op_ptrs); + else if constexpr(Rank == 4) + add_device_softmax_f16_f16_rank4_instances(op_ptrs); + } + else if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + if constexpr(Rank == 3) + add_device_softmax_f32_f32_rank3_instances(op_ptrs); + else if constexpr(Rank == 4) + add_device_softmax_f32_f32_rank4_instances(op_ptrs); + } + else if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + if constexpr(Rank == 3) + add_device_softmax_i8_i8_rank3_instances(op_ptrs); + else if constexpr(Rank == 4) + add_device_softmax_i8_i8_rank4_instances(op_ptrs); + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..83f52fc3ee7a9547161ce1284c2844976819c0ec --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_instances( + std::vector>& instances); +void add_device_softmax_f16_f16_rank4_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..046ff5780556331a6afc697f87dbbf02dcf1ff25 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_reduce1_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8e6a226f6a13fb1dcd7d61b56b476f38072b6a98 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_reduce2_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..518fa5f98679099189f15dfe5a57e3ac0324e3fa --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_reduce3_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..10016cdd707cbcb46282e540e15e6b79a67c6a78 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank4_reduce1_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cdd5a3cd7b6fb6806db3de0a10b7ed87b29c5c2d --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank4_reduce2_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a8be272e02014495cb42d0509b50075f2d7da707 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank4_reduce3_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec8296ff22fb6591152a3dcfad63c250c70cabb8 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank4_reduce4_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b3877c4bb3f6217427f69c07f7073751b254c938 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_softmax_f16_f16_instances = std::tuple< + // clang-format off + // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + // fallback kernel + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 8>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 8>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 8>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 8>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 8>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 8>, + // Reduction on middle dimensions + // InSrcVectorDim is 0 since we want to coalesce reads on M dimension + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 8, 4, 0, 1, 1>, + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 8, 4, 0, 8, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6d9a359f4622ec9534ab090d46e363bd9101177 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank3_instances( + std::vector>& instances); +void add_device_softmax_f32_f32_rank4_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6621a2c867ac90402a853e5c2065663a96220927 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank3_reduce1_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3dfac98ed8beeee6060b6ba3f6b9b1ff19f660fe --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank3_reduce2_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6d2a0c932500e874641d8cef9fd65b4d280c89b6 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank3_reduce3_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..97dd3dcb18aed5c228313a961994096445948bbf --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank4_reduce1_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..58f8760acccd1b571df40ea9a08946c1f5f9764f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank4_reduce2_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..df8d31f0da76b0f027b4ee4818006dd331e38454 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank4_reduce3_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1bd773227e172278a466d7650e90fd39ec3e56f8 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank4_reduce4_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..16f129d2d07c19a926c75f1e878e8db3fb9bf821 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_softmax_f32_f32_instances = std::tuple< + // clang-format off + // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, // fallback kernel + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4>, + // Reduction on middle dimensions + // InSrcVectorDim is 0 since we want to coalesce reads on M dimension + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 8, 4, 0, 1, 1>, + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 8, 4, 0, 4, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f80f712ff5e501d6516d9c7210a59276a96b6541 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank3_instances( + std::vector>& instances); +void add_device_softmax_i8_i8_rank4_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f9952e7d58a448cfea0e3c6ba84af24e9de4e52 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank3_reduce1_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2cbd13a1ba5226316363cfb07dfd23adfac4aa9f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank3_reduce2_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b12522a85987f85e83a6f8afb34b97844507c49 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank3_reduce3_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54d477f80c5b222cd5a76430dae3bbd4db81068c --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank4_reduce1_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4ffc44e3a9217e689297bceb81bd6a38c41cd1c3 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank4_reduce2_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..08cbb81272f9ff5e7ce9a9fdd9075d81f1ed7211 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank4_reduce3_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp new file mode 100644 index 0000000000000000000000000000000000000000..187d034b95ac506426da4fa1b8240b61e2517e46 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank4_reduce4_instances( + std::vector>& instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7fc9ed69198f1e7bf447a6baa9dedb3286440a94 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_softmax_i8_i8_instances = std::tuple< + // clang-format off + // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + // fallback kernel + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 16, 1, 1, 1>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 16, 1, 16, 16>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 16, 1, 16, 16>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 16, 1, 16, 16>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 32, 1, 16, 16>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 64, 1, 16, 16>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 16, 1, 16, 16>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 32, 1, 16, 16>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 64, 1, 16, 16>, + // Reduction on middle dimensions + // InSrcVectorDim is 0 since we want to coalesce reads on M dimension + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 8, 8, 0, 1, 1>, + DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 32, 8, 32, 8, 0, 16, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03be6e2bc7c15ee350569f7ed97ebf2d3ad3b515 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp" diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/algorithm.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/algorithm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86f04dd362397e774c9821a16a7c412dea000937 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/algorithm.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace ranges { +template +auto copy(InputRange&& range, OutputIterator iter) + -> decltype(std::copy(std::begin(std::forward(range)), + std::end(std::forward(range)), + iter)) +{ + return std::copy(std::begin(std::forward(range)), + std::end(std::forward(range)), + iter); +} + +template +auto fill(OutputRange&& range, const T& init) + -> std::void_t(range)), + std::end(std::forward(range)), + init))> +{ + std::fill(std::begin(std::forward(range)), + std::end(std::forward(range)), + init); +} + +template +auto transform(InputRange&& range, OutputIterator iter, UnaryOperation unary_op) + -> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op)) +{ + return std::transform(std::begin(range), std::end(range), iter, unary_op); +} + +} // namespace ranges +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/check_err.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/check_err.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a89d03d324f33f0be95f88c328a9d7aee38e7460 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/check_err.hpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/library/utility/ranges.hpp" + +namespace ck { +namespace utils { + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_floating_point_v> && + !std::is_same_v, half_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = *std::next(std::begin(out), i); + const double r = *std::next(std::begin(ref), i); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_same_v, bhalf_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + // TODO: This is a hack. We should have proper specialization for bhalf_t data type. + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_same_v, half_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits>::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_integral_v> && + !std::is_same_v, bhalf_t>) +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || std::is_same_v, int4_t> +#endif + , + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double atol = 0) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + int64_t err = 0; + int64_t max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const int64_t o = *std::next(std::begin(out), i); + const int64_t r = *std::next(std::begin(ref), i); + err = std::abs(o - r); + + if(err > atol) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r + << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << "max err: " << max_err << std::endl; + } + return res; +} + +} // namespace utils +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/conv_common.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/conv_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6fad9f7d77d349853e63512ae486a57e7ceb7b6f --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/conv_common.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" + +template +constexpr auto get_convolution_output_default_4d_tensor_descriptor( + const ck::TensorDescriptor& in_desc, + const ck::TensorDescriptor& wei_desc, + const ConvStrides& conv_strides, + const ConvDilations conv_dilations, + const LeftPads& left_pads, + const RightPads& right_pads) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + assert(in_desc.GetNumOfDimension() == 4); + assert(wei_desc.GetNumOfDimension() == 4); + assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1)); + + const auto N = in_desc.GetLength(I0); + const auto Hi = in_desc.GetLength(I2); + const auto Wi = in_desc.GetLength(I3); + + const auto K = wei_desc.GetLength(I0); + const auto Y = wei_desc.GetLength(I2); + const auto X = wei_desc.GetLength(I3); + + const auto LeftPadH = left_pads[I0]; + const auto LeftPadW = left_pads[I1]; + + const auto RightPadH = right_pads[I0]; + const auto RightPadW = right_pads[I1]; + + const auto YEff = (Y - I1) * conv_dilations[I0] + I1; + const auto XEff = (X - I1) * conv_dilations[I1] + I1; + + const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1; + const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1; + + return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); +} + +template +constexpr std::size_t +calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const index_t N = out_desc.GetLength(I0); + const index_t K = out_desc.GetLength(I1); + const index_t Ho = out_desc.GetLength(I2); + const index_t Wo = out_desc.GetLength(I3); + + const index_t C = wei_desc.GetLength(I1); + const index_t Y = wei_desc.GetLength(I2); + const index_t X = wei_desc.GetLength(I3); + + return std::size_t(2) * N * K * Ho * Wo * C * Y * X; +} diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2b4f63b28b8d2f1e59cef67b99a577314faf4ae2 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace utils { +namespace conv { + +namespace detail { + +template +std::vector get_layout_transpose_gnchw_to_old() +{ + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 3, 2}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 4, 2, 3}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 5, 2, 3, 4}; + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 2, 3}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 2, 3, 4}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 2, 3, 4, 5}; + } + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 3, 2}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 4, 2, 3}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 5, 2, 3, 4}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {2, 0, 3, 1}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {3, 0, 4, 1, 2}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {4, 0, 5, 1, 2, 3}; + } + else + { + printf("%s\n", __func__); + throw std::runtime_error("wrong! unsupported layout"); + } +} + +} // namespace detail + +// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW +// regardless of physical layout +template +HostTensorDescriptor +make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvParam& param) +{ + std::vector physical_lengths; + + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! G != 1"); + } + + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", InLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX +// regardless of physical layout +template +HostTensorDescriptor +make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck::utils::conv::ConvParam& param) +{ + std::vector physical_lengths; + + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! G != 1"); + } + + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! G != 1"); + } + + physical_lengths = std::vector{static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.K_), + static_cast(param.G_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", WeiLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW +// regardless of physical layout +template +HostTensorDescriptor +make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvParam& param) +{ + std::vector physical_lengths; + + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! G != 1"); + } + + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.end(), + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", OutLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +} // namespace conv +} // namespace utils +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/convolution_parameter.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/convolution_parameter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f4a2b56f75a6b40a476080e74d386fab524286a4 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/convolution_parameter.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/numeric.hpp" + +namespace ck { +namespace utils { +namespace conv { + +struct ConvParam +{ + ConvParam(); + ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads); + + ck::index_t num_dim_spatial_; + ck::index_t G_; + ck::index_t N_; + ck::index_t K_; + ck::index_t C_; + + std::vector filter_spatial_lengths_; + std::vector input_spatial_lengths_; + std::vector output_spatial_lengths_; + + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + + std::vector input_left_pads_; + std::vector input_right_pads_; + + std::vector GetOutputSpatialLengths() const; + + std::size_t GetFlops() const; + + template + std::size_t GetInputByte() const + { + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * + (G_ * N_ * C_ * + ck::accumulate_n( + std::begin(input_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>())); + } + + template + std::size_t GetWeightByte() const + { + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * + (G_ * K_ * C_ * + ck::accumulate_n( + std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>())); + } + + template + std::size_t GetOutputByte() const + { + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * (G_ * N_ * K_ * + std::accumulate(std::begin(output_spatial_lengths_), + std::end(output_spatial_lengths_), + static_cast(1), + std::multiplies())); + } + + template + std::size_t GetByte() const + { + return GetInputByte() + GetWeightByte() + + GetOutputByte(); + } +}; + +std::string get_conv_param_parser_helper_msg(); + +ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]); + +} // namespace conv +} // namespace utils +} // namespace ck + +std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p); diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/device_memory.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/device_memory.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3c4ece44068dd2dafe5869f6943a0b27b2508355 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/device_memory.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +template +__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) +{ + for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) + { + p[i] = x; + } +} + +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer() const; + std::size_t GetBufferSize() const; + void ToDevice(const void* p) const; + void FromDevice(void* p) const; + void SetZero() const; + template + void SetValue(T x) const; + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +template +void DeviceMem::SetValue(T x) const +{ + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } + + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); +} diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/fill.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54d58f362ccaffad209df586e2b9a399120d1d88 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/fill.hpp @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace utils { + +template +struct FillUniformDistribution +{ + float a_{-5.f}; + float b_{5.f}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. +// However this produces segfaults in std::mt19937 which look like inifite loop. +// template +// struct FillUniformDistributionIntegerValue +// { +// int a_{-5}; +// int b_{5}; +// +// template +// void operator()(ForwardIter first, ForwardIter last) const +// { +// std::mt19937 gen(11939); +// std::uniform_int_distribution dis(a_, b_); +// std::generate( +// first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); +// } +// }; + +// Workaround for uniform_int_distribution not working as expected. See note above.< +template +struct FillUniformDistributionIntegerValue +{ + float a_{-5.f}; + float b_{5.f}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate( + first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct FillMonotonicSeq +{ + T init_value_{0}; + T step_{1}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::generate(first, last, [=, n = init_value_]() mutable { + auto tmp = n; + n += step_; + return tmp; + }); + } +}; + +template +struct FillConstant +{ + T value_{0}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::fill(first, last, value_); + } +}; + +} // namespace utils +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/host_common_util.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/host_common_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f4466e8da0e2b77c1ad981cfb6f730a74a92d71 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/host_common_util.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { + +namespace host_common { + +template +static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) +{ + std::ofstream outFile(fileName, std::ios::binary); + if(outFile) + { + outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); + outFile.close(); + std::cout << "Write output to file " << fileName << std::endl; + } + else + { + std::cout << "Could not open file " << fileName << " for writing" << std::endl; + } +}; + +template +static inline T getSingleValueFromString(const std::string& valueStr) +{ + std::istringstream iss(valueStr); + + T val; + + iss >> val; + + return (val); +}; + +template +static inline std::vector getTypeValuesFromString(const char* cstr_values) +{ + std::string valuesStr(cstr_values); + + std::vector values; + std::size_t pos = 0; + std::size_t new_pos; + + new_pos = valuesStr.find(',', pos); + while(new_pos != std::string::npos) + { + const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); + + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + pos = new_pos + 1; + new_pos = valuesStr.find(',', pos); + }; + + std::string sliceStr = valuesStr.substr(pos); + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + return (values); +} + +template +static inline std::vector> +get_index_set(const std::array& dim_lengths) +{ + static_assert(NDim >= 1, "NDim >= 1 is required to use this function!"); + + if constexpr(NDim == 1) + { + std::vector> index_set; + + for(int i = 0; i < dim_lengths[0]; i++) + { + std::array index{i}; + + index_set.push_back(index); + }; + + return index_set; + } + else + { + std::vector> index_set; + std::array partial_dim_lengths; + + std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin()); + + std::vector> partial_index_set; + + partial_index_set = get_index_set(partial_dim_lengths); + + for(index_t i = 0; i < dim_lengths[0]; i++) + for(const auto& partial_index : partial_index_set) + { + std::array index; + + index[0] = i; + + std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1); + + index_set.push_back(index); + }; + + return index_set; + }; +}; + +template +static inline size_t get_offset_from_index(const std::array& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += index[i] * strides[i]; + + return (offset); +}; + +} // namespace host_common +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/host_conv.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/host_conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8348a3089f4606563caa5582ac427154308e7fa5 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/host_conv.hpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "host_tensor.hpp" +#include "conv_common.hpp" + +template +void host_conv_nchw_kcyx_nkhw(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&) +{ + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; + + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += ck::type_convert(in(n, c, hi, wi)) * + ck::type_convert(wei(k, c, y, x)); + } + } + } + } + out(n, k, ho, wo) = ck::type_convert(v); + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); +} + +template +void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + const auto Di = in.mDesc.GetLengths()[1]; + const auto Hi = in.mDesc.GetLengths()[2]; + const auto Wi = in.mDesc.GetLengths()[3]; + const auto Z = wei.mDesc.GetLengths()[1]; + const auto Y = wei.mDesc.GetLengths()[2]; + const auto X = wei.mDesc.GetLengths()[3]; + const auto C = wei.mDesc.GetLengths()[4]; + + auto f_ndhwc = [&](auto n, auto do_tmp, auto ho_tmp, auto wo_tmp, auto k) { + // do__ must be converted to signed integer, otherwise zmin might be wrong in cases + // negative values. + const int do_ = static_cast(do_tmp); + const int ho = static_cast(ho_tmp); + const int wo = static_cast(wo_tmp); + const int zmin = + std::max(0, + (in_left_pads[I0] - do_ * conv_strides[I0] + conv_dilations[I0] - 1) / + conv_dilations[I0]); + const int ymin = + std::max(0, + (in_left_pads[I1] - ho * conv_strides[I1] + conv_dilations[I1] - 1) / + conv_dilations[I1]); + const int xmin = + std::max(0, + (in_left_pads[I2] - wo * conv_strides[I2] + conv_dilations[I2] - 1) / + conv_dilations[I2]); + const int zmax = + std::min(Z, (in_left_pads[I0] - do_ * conv_strides[I0] + Di) / conv_dilations[I0]); + const int ymax = + std::min(Y, (in_left_pads[I1] - ho * conv_strides[I1] + Hi) / conv_dilations[I1]); + const int xmax = + std::min(X, (in_left_pads[I2] - wo * conv_strides[I2] + Wi) / conv_dilations[I2]); + const int di_min = do_ * conv_strides[I0] + zmin * conv_dilations[I0] - in_left_pads[I0]; + const int hi_min = ho * conv_strides[I1] + ymin * conv_dilations[I1] - in_left_pads[I1]; + const int wi_min = wo * conv_strides[I2] + xmin * conv_dilations[I2] - in_left_pads[I2]; + + double v = 0; + + const TIn* in_n = in.mData.data() + n * Di * Hi * Wi * C; + const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C; + + int di = di_min; + for(int z = zmin; z < zmax; ++z, di += conv_dilations[I0]) + { + const TIn* in_n_di = in_n + di * Hi * Wi * C; + const TWei* wei_k_z = wei_k + z * Y * X * C; + int hi = hi_min; + + for(int y = ymin; y < ymax; ++y, hi += conv_dilations[I1]) + { + const TIn* in_n_di_hi = in_n_di + hi * Wi * C; + const TWei* wei_k_z_y = wei_k_z + y * X * C; + int wi = wi_min; + + for(int x = xmin; x < xmax; ++x, wi += conv_dilations[I2]) + { + const TIn* in_n_di_hi_wi = in_n_di_hi + wi * C; + const TWei* wei_k_z_y_x = wei_k_z_y + x * C; + + for(int c = 0; c < C; ++c) + { + v += static_cast(in_n_di_hi_wi[c]) * + static_cast(wei_k_z_y_x[c]); + } + } + } + } + + out(n, do_, ho, wo, k) = v; + }; + + make_ParallelTensorFunctor(f_ndhwc, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3], + out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency() - 4); +} diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/host_gemm.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/host_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..44036d0234375824ed3aa773ffe71de02b84c570 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/host_gemm.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "host_tensor.hpp" + +template +void host_gemm_mk_kn_mn(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) +{ + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a_m_k.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + a_element_op(v_a, static_cast(a_m_k(m, k))); + b_element_op(v_b, static_cast(b_k_n(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + c_element_op(v_c, v_acc); + + c_m_n(m, n) = v_c; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, + c_m_n.mDesc.GetLengths()[0], + c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); +} diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/host_reduction.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/host_reduction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7c0c969ac592a459c5df37fbfcf807f2f500df37 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/host_reduction.hpp @@ -0,0 +1,374 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +static void get_all_indexes(const std::array& dimLengths, + std::vector>& indexes) +{ + static_assert(NDim >= 1, "NDim >= 1 is required to use this function!"); + + if constexpr(NDim == 1) + { + for(size_t i = 0; i < dimLengths[0]; i++) + { + std::array index{i}; + + indexes.push_back(index); + }; + } + else + { + std::array partial_dim_lengths; + + for(int i = 0; i < NDim - 1; i++) + partial_dim_lengths[i] = dimLengths[i + 1]; + + std::vector> partial_indexes; + + get_all_indexes(partial_dim_lengths, partial_indexes); + + for(size_t i = 0; i < dimLengths[0]; i++) + for(const auto& index : partial_indexes) + { + std::array extIndex; + + extIndex[0] = i; + + for(int k = 0; k < NDim - 1; k++) + extIndex[k + 1] = index[k]; + + indexes.push_back(extIndex); + }; + }; +}; + +template +static size_t get_offset_from_index(const std::array& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += strides[i] * index[i]; + + return (offset); +}; + +template +static size_t get_offset_from_index(const std::vector& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += strides[i] * index[i]; + + return (offset); +}; + +template +struct ReductionHost +{ + using IndexDataType = int32_t; + + static constexpr int NumInvariantDim = Rank - NumReduceDim; + + std::vector outStrides; + + IndexDataType divider; + + std::array reduceLengths; + std::array reduceStrides; + std::array invariantLengths; + std::array invariantStrides; + + std::vector> reduce_dim_indexes; + std::vector> invariant_dim_indexes; + + ReductionHost(HostTensorDescriptor& inDesc, + HostTensorDescriptor& outDesc, + const std::array invariantDims, + const std::array reduceDims) + { + // this->outLengths = to_int_vector(outDesc.GetLengths()); + this->outStrides = outDesc.GetStrides(); + + int product = 1; + + for(int i = 0; i < NumReduceDim; i++) + { + reduceLengths[i] = inDesc.GetLengths()[reduceDims[i]]; + reduceStrides[i] = inDesc.GetStrides()[reduceDims[i]]; + product *= inDesc.GetLengths()[reduceDims[i]]; + }; + + divider = product; + + for(int i = 0; i < NumInvariantDim; i++) + { + invariantLengths[i] = inDesc.GetLengths()[invariantDims[i]]; + invariantStrides[i] = inDesc.GetStrides()[invariantDims[i]]; + }; + + reduce_dim_indexes.clear(); + get_all_indexes(reduceLengths, reduce_dim_indexes); + + if constexpr(NumInvariantDim > 0) + { + invariant_dim_indexes.clear(); + get_all_indexes(invariantLengths, invariant_dim_indexes); + }; + }; + + void Run(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) + { + if constexpr(OutputIndex) + { + RunImpl_with_index( + alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op); + } + else + { + RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op); + }; + }; + + void RunImpl_with_index(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) + { + using ck::float_equal_one; + using ck::float_equal_zero; + using ck::type_convert; + + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; + + if constexpr(NumInvariantDim == 0) + { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); + + auto currVal = type_convert(in_data[offset_reduce]); + + in_elementwise_op(currVal, currVal); + + auto currIndex = static_cast(i); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[0]) * type_convert(beta); + + out_data[0] = type_convert(accuVal); + out_indices[0] = accuIndex; + } + else + { + auto thread_reduce_func = [&](auto invariant_index) { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + auto offset_invariant = + get_offset_from_index(invariantStrides, invariant_index); + + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); + + auto currVal = + type_convert(in_data[offset_invariant + offset_reduce]); + + in_elementwise_op(currVal, currVal); + + auto currIndex = static_cast(i); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + auto dst_offset = + get_offset_from_index(outStrides, invariant_index); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[dst_offset]) * + type_convert(beta); + + out_data[dst_offset] = type_convert(accuVal); + out_indices[dst_offset] = accuIndex; + }; + + std::size_t num_thread = 1; + std::size_t work_per_thread = + (invariant_dim_indexes.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = + std::min((it + 1) * work_per_thread, invariant_dim_indexes.size()); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + thread_reduce_func(invariant_dim_indexes[iw]); + } + }; + + threads[it] = joinable_thread(f); + } + }; + }; + + void RunImpl_no_index(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) + { + using ck::float_equal_one; + using ck::float_equal_zero; + using ck::type_convert; + + using Accumulation = + ck::detail::AccumulateWithNanCheck; + + if constexpr(NumInvariantDim == 0) + { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + + for(const auto& reduce_index : reduce_dim_indexes) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_index); + + auto currVal = type_convert(in_data[offset_reduce]); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[0]) * type_convert(beta); + + out_data[0] = type_convert(accuVal); + } + else + { + auto thread_reduce_func = [&](auto invariant_index) { + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); + + auto offset_invariant = + get_offset_from_index(invariantStrides, invariant_index); + + for(const auto& reduce_index : reduce_dim_indexes) + { + auto offset_reduce = + get_offset_from_index(reduceStrides, reduce_index); + + auto currVal = + type_convert(in_data[offset_invariant + offset_reduce]); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + }; + + acc_elementwise_op(accuVal, accuVal); + + if(!float_equal_one{}(alpha)) + accuVal *= type_convert(alpha); + + auto dst_offset = + get_offset_from_index(outStrides, invariant_index); + + if(!float_equal_zero{}(beta)) + accuVal += type_convert(out_data[dst_offset]) * + type_convert(beta); + + out_data[dst_offset] = type_convert(accuVal); + }; + + std::size_t num_thread = 1; + std::size_t work_per_thread = + (invariant_dim_indexes.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = + std::min((it + 1) * work_per_thread, invariant_dim_indexes.size()); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + thread_reduce_func(invariant_dim_indexes[iw]); + } + }; + + threads[it] = joinable_thread(f); + } + }; + }; +}; diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/host_tensor.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/host_tensor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a8c7fd039537f7a691a08784ce33f39f444ab88b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/host_tensor.hpp @@ -0,0 +1,472 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/span.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/ranges.hpp" + +template +std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << v; + } + return os; +} + +template +std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << static_cast(v); + } + return os; +} + +template +auto call_f_unpack_args_impl(F f, T args, std::index_sequence) +{ + return f(std::get(args)...); +} + +template +auto call_f_unpack_args(F f, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return call_f_unpack_args_impl(f, args, std::make_index_sequence{}); +} + +template +auto construct_f_unpack_args_impl(T args, std::index_sequence) +{ + return F(std::get(args)...); +} + +template +auto construct_f_unpack_args(F, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return construct_f_unpack_args_impl(args, std::make_index_sequence{}); +} + +struct HostTensorDescriptor +{ + HostTensorDescriptor() = default; + + void CalculateStrides(); + + template >> + HostTensorDescriptor(const std::initializer_list& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template , std::size_t>>> + HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template && + std::is_convertible_v>> + HostTensorDescriptor(const std::initializer_list& lens, + const std::initializer_list& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + template , std::size_t> && + std::is_convertible_v, std::size_t>>> + HostTensorDescriptor(const Lengths& lens, const Strides& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + std::size_t GetNumOfDimension() const; + std::size_t GetElementSize() const; + std::size_t GetElementSpaceSize() const; + + const std::vector& GetLengths() const; + const std::vector& GetStrides() const; + + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + assert(sizeof...(Is) == this->GetNumOfDimension()); + std::initializer_list iss{static_cast(is)...}; + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + std::size_t GetOffsetFromMultiIndex(std::vector iss) const + { + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); + + private: + std::vector mLens; + std::vector mStrides; +}; + +template +HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a, + const New2Old& new2old) +{ + std::vector new_lengths(a.GetNumOfDimension()); + std::vector new_strides(a.GetNumOfDimension()); + + for(std::size_t i = 0; i < a.GetNumOfDimension(); i++) + { + new_lengths[i] = a.GetLengths()[new2old[i]]; + new_strides[i] = a.GetStrides()[new2old[i]]; + } + + return HostTensorDescriptor(new_lengths, new_strides); +} + +struct joinable_thread : std::thread +{ + template + joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) + { + } + + joinable_thread(joinable_thread&&) = default; + joinable_thread& operator=(joinable_thread&&) = default; + + ~joinable_thread() + { + if(this->joinable()) + this->join(); + } +}; + +template +struct ParallelTensorFunctor +{ + F mF; + static constexpr std::size_t NDIM = sizeof...(Xs); + std::array mLens; + std::array mStrides; + std::size_t mN1d; + + ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast(xs)...}) + { + mStrides.back() = 1; + std::partial_sum(mLens.rbegin(), + mLens.rend() - 1, + mStrides.rbegin() + 1, + std::multiplies()); + mN1d = mStrides[0] * mLens[0]; + } + + std::array GetNdIndices(std::size_t i) const + { + std::array indices; + + for(std::size_t idim = 0; idim < NDIM; ++idim) + { + indices[idim] = i / mStrides[idim]; + i -= indices[idim] * mStrides[idim]; + } + + return indices; + } + + void operator()(std::size_t num_thread = 1) const + { + std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + call_f_unpack_args(mF, GetNdIndices(iw)); + } + }; + threads[it] = joinable_thread(f); + } + } +}; + +template +auto make_ParallelTensorFunctor(F f, Xs... xs) +{ + return ParallelTensorFunctor(f, xs...); +} + +template +struct Tensor +{ + using Descriptor = HostTensorDescriptor; + using Data = std::vector; + + template + Tensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + { + } + + template + Tensor(std::initializer_list lens, std::initializer_list strides) + : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize()) + { + } + + template + Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + { + } + + template + Tensor(const Lengths& lens, const Strides& strides) + : mDesc(lens, strides), mData(GetElementSpaceSize()) + { + } + + Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {} + + template + Tensor CopyAsType() const + { + Tensor ret(mDesc); + + ck::ranges::transform( + mData, ret.mData.begin(), [](auto value) { return ck::type_convert(value); }); + + return ret; + } + + Tensor() = delete; + Tensor(const Tensor&) = default; + Tensor(Tensor&&) = default; + + ~Tensor() = default; + + Tensor& operator=(const Tensor&) = default; + Tensor& operator=(Tensor&&) = default; + + template + explicit Tensor(const Tensor& other) : Tensor(other.template CopyAsType()) + { + } + + decltype(auto) GetLengths() const { return mDesc.GetLengths(); } + + decltype(auto) GetStrides() const { return mDesc.GetStrides(); } + + std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); } + + std::size_t GetElementSize() const { return mDesc.GetElementSize(); } + + std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); } + + std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } + + void SetZero() { ck::ranges::fill(mData, 0); } + + template + void ForEach_impl(F&& f, std::vector& idx, size_t rank) + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(F&& f) + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void ForEach_impl(const F&& f, std::vector& idx, size_t rank) const + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(const F&& f) const + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void GenerateTensorValue(G g, std::size_t num_thread = 1) + { + switch(mDesc.GetNumOfDimension()) + { + case 1: { + auto f = [&](auto i) { (*this)(i) = g(i); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread); + break; + } + case 2: { + auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread); + break; + } + case 3: { + auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); }; + make_ParallelTensorFunctor( + f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread); + break; + } + case 4: { + auto f = [&](auto i0, auto i1, auto i2, auto i3) { + (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3])(num_thread); + break; + } + case 5: { + auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) { + (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3], + mDesc.GetLengths()[4])(num_thread); + break; + } + case 6: { + auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) { + (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4, i5); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3], + mDesc.GetLengths()[4], + mDesc.GetLengths()[5])(num_thread); + break; + } + default: throw std::runtime_error("unspported dimension"); + } + } + + template + T& operator()(Is... is) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + template + const T& operator()(Is... is) const + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + T& operator()(std::vector idx) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + + const T& operator()(std::vector idx) const + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + + typename Data::iterator begin() { return mData.begin(); } + + typename Data::iterator end() { return mData.end(); } + + typename Data::pointer data() { return mData.data(); } + + typename Data::const_iterator begin() const { return mData.begin(); } + + typename Data::const_iterator end() const { return mData.end(); } + + typename Data::const_pointer data() const { return mData.data(); } + + typename Data::size_type size() const { return mData.size(); } + + template + auto AsSpan() const + { + constexpr std::size_t FromSize = sizeof(T); + constexpr std::size_t ToSize = sizeof(U); + + using Element = std::add_const_t>; + return ck::span{reinterpret_cast(data()), size() * FromSize / ToSize}; + } + + template + auto AsSpan() + { + constexpr std::size_t FromSize = sizeof(T); + constexpr std::size_t ToSize = sizeof(U); + + using Element = std::remove_reference_t; + return ck::span{reinterpret_cast(data()), size() * FromSize / ToSize}; + } + + Descriptor mDesc; + Data mData; +}; diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/host_tensor_generator.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/host_tensor_generator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4259862e65e331b30377997c512459e2a21e9383 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/host_tensor_generator.hpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" + +template +struct GeneratorTensor_0 +{ + template + T operator()(Is...) + { + return T{0}; + } +}; + +template +struct GeneratorTensor_1 +{ + T value = 1; + + template + T operator()(Is...) + { + return value; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bhalf_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + int8_t operator()(Is...) + { + return value; + } +}; + +template +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + T operator()(Is...) + { + return static_cast((std::rand() % (max_value - min_value)) + min_value); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bhalf_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + int8_t operator()(Is...) + { + return (std::rand() % (max_value - min_value)) + min_value; + } +}; + +template +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + T operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + return static_cast(min_value + tmp * (max_value - min_value)); + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bhalf_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; + +template +struct GeneratorTensor_4 +{ + std::default_random_engine generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev) : generator(1), distribution(mean, stddev){}; + + template + T operator()(Is...) + { + float tmp = distribution(generator); + + return ck::type_convert(tmp); + } +}; + +struct GeneratorTensor_Checkboard +{ + template + float operator()(Ts... Xs) const + { + std::array dims = {static_cast(Xs)...}; + return std::accumulate(dims.begin(), + dims.end(), + true, + [](bool init, ck::index_t x) -> int { return init != (x % 2); }) + ? 1 + : -1; + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + float operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + return dims[Dim]; + } +}; + +template +struct GeneratorTensor_Diagonal +{ + T value{1}; + + template + T operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + size_t start_dim = dims.size() - NumEffectiveDim; + bool pred = true; + for(size_t i = start_dim + 1; i < dims.size(); i++) + { + pred &= (dims[start_dim] == dims[i]); + } + return pred ? value : T{0}; + } +}; diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/iterator.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9fdc88ea76825e2a1cdf659f95fa2928038867aa --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/iterator.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/type.hpp" + +namespace ck { + +template +using iter_value_t = typename std::iterator_traits>::value_type; + +template +using iter_reference_t = decltype(*std::declval()); + +template +using iter_difference_t = typename std::iterator_traits>::difference_type; + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/literals.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/literals.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a73a2ea054150665557b8e1da473f2cb9877526b --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/literals.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace literals { +// [P0330] Literal Suffix for (signed) size_t (C++23) +// ref: https://wg21.link/p0330r8 +inline constexpr std::size_t operator""_uz(unsigned long long size) +{ + return static_cast(size); +} + +inline constexpr std::size_t operator""_zu(unsigned long long size) +{ + return static_cast(size); +} +} // namespace literals +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/numeric.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/numeric.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70a7e87ab1c8d94569e705b3f6832c78bac05853 --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/numeric.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +namespace ck { +template +auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) + -> decltype(std::accumulate(first, std::next(first, count), init, op)) +{ + return std::accumulate(first, std::next(first, count), init, op); +} +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/op_instance_engine.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/op_instance_engine.hpp new file mode 100644 index 0000000000000000000000000000000000000000..78812e8c81d600ab1258400205b98994ab120d3a --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/op_instance_engine.hpp @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/functional2.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace utils { + +struct ProfileBestConfig +{ + std::string best_op_name; + float best_avg_time = std::numeric_limits::max(); + float best_tflops = std::numeric_limits::max(); + float best_gb_per_sec = std::numeric_limits::max(); +}; + +/** + * @brief This class describes an operation instance(s). + * + * Op instance defines a particular specializations of operator + * template. Thanks to this specific input/output data types, data + * layouts and modifying elementwise operations it is able to create + * it's input/output tensors, provide pointers to instances which + * can execute it and all operation specific parameters. + */ +template +class OpInstance +{ + public: + template + using TensorPtr = std::unique_ptr>; + using InTensorsTuple = std::tuple...>; + using DeviceMemPtr = std::unique_ptr; + using DeviceBuffers = std::vector; + + OpInstance() = default; + OpInstance(const OpInstance&) = default; + OpInstance& operator=(const OpInstance&) = default; + virtual ~OpInstance(){}; + + virtual InTensorsTuple GetInputTensors() const = 0; + virtual TensorPtr GetOutputTensor() const = 0; + virtual std::unique_ptr + MakeInvokerPointer(tensor_operation::device::BaseOperator*) const = 0; + virtual std::unique_ptr + MakeArgumentPointer(tensor_operation::device::BaseOperator*, + const DeviceBuffers&, + const DeviceMemPtr&) const = 0; + virtual std::size_t GetFlops() const = 0; + virtual std::size_t GetBtype() const = 0; +}; + +/** + * @brief A generic operation instance run engine. + */ +template +class OpInstanceRunEngine +{ + public: + using OpInstanceT = OpInstance; + template + using TensorPtr = std::unique_ptr>; + using DeviceMemPtr = std::unique_ptr; + using InTensorsTuple = std::tuple...>; + using DeviceBuffers = std::vector; + using InArgsTypesTuple = std::tuple; + + OpInstanceRunEngine() = delete; + + template > + OpInstanceRunEngine(const OpInstanceT& op_instance, + const ReferenceOp& reference_op = ReferenceOp{}, + bool do_verification = true) + : op_instance_{op_instance} + { + in_tensors_ = op_instance_.GetInputTensors(); + out_tensor_ = op_instance_.GetOutputTensor(); + + if constexpr(std::is_invocable_v&..., + Tensor&>) + { + if(do_verification) + { + ref_output_ = op_instance_.GetOutputTensor(); + CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + } + } + AllocateDeviceInputTensors(std::make_index_sequence{}); + out_device_buffer_ = std::make_unique(sizeof(OutDataType) * + out_tensor_->mDesc.GetElementSpaceSize()); + out_device_buffer_->SetZero(); + } + + virtual ~OpInstanceRunEngine(){}; + + template + bool Test(const std::vector& op_ptrs) + { + bool res{true}; + for(auto& op_ptr : op_ptrs) + { + auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get()); + auto argument = op_instance_.MakeArgumentPointer( + op_ptr.get(), in_device_buffers_, out_device_buffer_); + if(op_ptr->IsSupportedArgument(argument.get())) + { + std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl; + invoker->Run(argument.get()); + out_device_buffer_->FromDevice(out_tensor_->mData.data()); + if(!ref_output_) + { + throw std::runtime_error( + "OpInstanceRunEngine::Test: Reference value not availabe." + " You have to provide reference function."); + } + // TODO: enable flexible use of custom check_error functions + bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData); + std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl; + res = res && inst_res; + out_device_buffer_->SetZero(); + } + else + { + std::cout << "Given conv problem is not supported by instance: \n\t>>>>" + << op_ptr->GetTypeString() << std::endl; + } + } + return res; + } + + template + ProfileBestConfig Profile(const std::vector& op_ptrs, + bool time_kernel = false, + bool do_verification = false, + bool do_log = false) + { + ProfileBestConfig best_config; + + for(auto& op_ptr : op_ptrs) + { + auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get()); + auto argument = op_instance_.MakeArgumentPointer( + op_ptr.get(), in_device_buffers_, out_device_buffer_); + if(op_ptr->IsSupportedArgument(argument.get())) + { + std::string op_name = op_ptr->GetTypeString(); + float avg_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flops = op_instance_.GetFlops(); + std::size_t num_btype = op_instance_.GetBtype(); + float tflops = static_cast(flops) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(avg_time < best_config.best_avg_time) + { + best_config.best_op_name = op_name; + best_config.best_tflops = tflops; + best_config.best_gb_per_sec = gb_per_sec; + best_config.best_avg_time = avg_time; + } + + if(do_verification) + { + out_device_buffer_->FromDevice(out_tensor_->mData.data()); + if(!ref_output_) + { + throw std::runtime_error( + "OpInstanceRunEngine::Profile: Reference value not availabe." + " You have to provide reference function."); + } + // TODO: enable flexible use of custom check_error functions + CheckErr(out_tensor_->mData, ref_output_->mData); + + if(do_log) {} + } + out_device_buffer_->SetZero(); + } + } + return best_config; + } + + void SetAtol(double a) { atol_ = a; } + void SetRtol(double r) { rtol_ = r; } + + private: + template + void CallRefOpUnpackArgs(const F& f, std::index_sequence) const + { + f(*std::get(in_tensors_)..., *ref_output_); + } + + template + void AllocateDeviceInputTensors(std::index_sequence) + { + (AllocateDeviceInputTensorsImpl(), ...); + } + + template + void AllocateDeviceInputTensorsImpl() + { + const auto& ts = std::get(in_tensors_); + in_device_buffers_ + .emplace_back( + std::make_unique(sizeof(std::tuple_element_t) * + ts->mDesc.GetElementSpaceSize())) + ->ToDevice(ts->mData.data()); + } + + static constexpr std::size_t kNInArgs_ = std::tuple_size_v; + const OpInstanceT& op_instance_; + double rtol_{1e-5}; + double atol_{1e-8}; + + InTensorsTuple in_tensors_; + TensorPtr out_tensor_; + TensorPtr ref_output_; + + DeviceBuffers in_device_buffers_; + DeviceMemPtr out_device_buffer_; + + template + bool CheckErr(const std::vector& dev_out, const std::vector& ref_out) const + { + return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_); + } +}; + +} // namespace utils +} // namespace ck diff --git a/3rdparty/composable_kernel/library/include/ck/library/utility/ranges.hpp b/3rdparty/composable_kernel/library/include/ck/library/utility/ranges.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55c322f1ace041b9c4f346a5df4250c2323759dd --- /dev/null +++ b/3rdparty/composable_kernel/library/include/ck/library/utility/ranges.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/library/utility/iterator.hpp" + +namespace ck { +namespace ranges { + +template +using iterator_t = decltype(std::begin(std::declval())); + +template +using sentinel_t = decltype(std::end(std::declval())); + +template +using range_size_t = decltype(std::size(std::declval())); + +template +using range_difference_t = ck::iter_difference_t>; + +template +using range_value_t = iter_value_t>; + +template +using range_reference_t = iter_reference_t>; + +template +struct is_range : std::false_type +{ +}; + +template +struct is_range< + T, + std::void_t())), decltype(std::end(std::declval()))>> + : std::true_type +{ +}; + +template +inline constexpr bool is_range_v = is_range::value; + +template +struct is_sized_range : std::false_type +{ +}; + +template +struct is_sized_range()))>> + : std::bool_constant> +{ +}; +} // namespace ranges +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9b7cd3a958e810c229d3620260e88d8e3e43bc1d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -0,0 +1,69 @@ +function(add_instance_library INSTANCE_NAME) + message("adding instance ${INSTANCE_NAME}") + add_library(${INSTANCE_NAME} OBJECT ${ARGN}) + target_compile_features(${INSTANCE_NAME} PUBLIC) + set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + clang_tidy_check(${INSTANCE_NAME}) +endfunction(add_instance_library INSTANCE_NAME) + + +file(GLOB dir_list LIST_DIRECTORIES true *) +set(CK_DEVICE_INSTANCES) +FOREACH(subdir_path ${dir_list}) +set(target_dir) +IF(IS_DIRECTORY "${subdir_path}") + get_filename_component(target_dir ${subdir_path} NAME) + add_subdirectory(${target_dir}) + list(APPEND CK_DEVICE_INSTANCES $) +ENDIF() +ENDFOREACH() + +add_library(device_operations STATIC ${CK_DEVICE_INSTANCES}) +add_library(composablekernels::device_operations ALIAS device_operations) + + +set(DEV_OPS_INC_DIRS + ${PROJECT_SOURCE_DIR}/include/ck/ + ${PROJECT_SOURCE_DIR}/library/include/ck/ +) + +target_compile_features(device_operations PUBLIC) +set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(device_operations PUBLIC + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ +) + +#once new arches are enabled make this an option on the main cmake file +# and pass down here to be exported +target_compile_options(device_operations PRIVATE + --offload-arch=gfx908 + --offload-arch=gfx90a +) + +target_compile_options(device_operations PRIVATE --gpu-max-threads-per-block=1024) + +# install(TARGETS device_operations LIBRARY DESTINATION lib) +rocm_install(TARGETS device_operations + EXPORT device_operationsTargets) + +rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) +rocm_install(EXPORT device_operationsTargets + FILE composable_kerneldevice_operationsTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0f2a7391999db23c9df9578b687ee11fa2198b26 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -0,0 +1,18 @@ +add_instance_library(device_batched_gemm_instance + device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc8787458d7ffff072cf8c055c8def43ca487e51 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..04200cfb52e26f227d89852e2845547e6624c22d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7b86f3cc728fdac1117cbe927445b6830e81e9ab --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2afb1afbc3d2355f682af9cad21b62c0ab1ee8ea --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..68d768949c859f5ca8aedfebfa888253dc6c276c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..737e5bfca3bccbafcd8179d80bb43b58bd545639 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e09d01736083f54cf2c1b848814b2d7fb683d3b8 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..984d66e2884273e006772530e6a978891766a8e0 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12cada9c44d61d8b81ad2dfa4877dc57b26ed681 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13f198862e5583b1f63c51efaab11fa733976cfe --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ca1adc2f6d160a97a2f40b4075f07b185e68351 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe5de52796f6c9a0c94a84d883a7b1a714c4a652 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceBatchedGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b55c8e15e01d879cc11e57128c3e2dfdf81890f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 1, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9517e4577e5b9af188aaf07413e8c0b3cf6c32ee --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43b91244060ed90c9f2558dc99279ce0a086845b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..326500fcbf013e8540fff9252202ac7af33b2938 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using AData = int8_t; +using BData = int8_t; +using CData = int8_t; +using AccData = int32_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceBatchedGemmXdl< AData, BData, CData, AccData, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d0e9b265af541a7717e9d25a3aef96e4196298ba --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_batched_gemm_add_relu_gemm_add_instance + device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1bfa88f49d4fb886f953661c03b8638516ba646 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu; +using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| + //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + // no padding + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, + // Padded fallback kernel + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector, + Row, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f59b742534b2c234b2b4908f5e540f591b104619 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu; +using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances = + std::tuple< + // clang-format off + //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| + //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + // no padding + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, + // Padded fallback kernel + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector, + Col, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..865a31e79a5dfdeb3af54a7f59fc40dc4e90070b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_batched_gemm_gemm_instance + device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b96194c874c3daea7c9c4eb931368ce5645b50e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple< + // clang-format off + //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, + // Padded fallback kernel + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0713dfcd99ca183d8f64cd23ac6067be990b39e7 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances = std::tuple< + // clang-format off + //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, + // Padded fallback kernel + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..db3719cff8a16a7f8dd2a42cbb8da1009555070a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,7 @@ +add_instance_library(device_batched_gemm_reduce_instance + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +) + diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..521c3d9219b1cd23e30ca194d3e22d1f38fa9316 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..231d612d781bc93c2a232f7e1478fddfb09c1b5d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..165bc3957d36479b1cb3052c0085ade58fb71322 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..832fc3b066066f9ac8faae30cebff714b73b53c9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..29fce56610929805ea706dbbea6cdfe18fdcc44d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_batched_gemm_softmax_gemm_instance + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +) + diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99e8712474996a6c30e8bf16ae69f5bc15925364 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +template +using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut| + //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, Masking>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking> + // clang-format on + >; + +template +using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances = + std::tuple< + // clang-format off + //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut| + //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking> + // clang-format on + >; + +void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + false>{}); + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances< + false>{}); +} + +void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + true>{}); + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances< + true>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..76121ffc347e0b8019c2ed1d18b6e71a5b0fb518 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -0,0 +1,5 @@ +add_instance_library(device_batched_gemm_softmax_gemm_permute_instance + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp +) + diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53ad7ba5ffa9272da26189a45a75a9453af68f48 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +template +using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| + // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + +void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskOutUpperTriangle>{}); +} + +void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskDisabled>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21da6895e66be8597fdeda60501e2d25716f3221 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +template +using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| + // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + +void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskOutUpperTriangle>{}); +} + +void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskDisabled>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d12a2f244fc8b0dedc42ac716ea8354a5c1dee72 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt @@ -0,0 +1,10 @@ +add_instance_library(device_batchnorm_instance + device_batchnorm_forward_f16_instance.cpp + device_batchnorm_forward_f32_instance.cpp + device_batchnorm_forward_bf16_instance.cpp + device_batchnorm_forward_f64_instance.cpp + device_batchnorm_backward_f16_instance.cpp + device_batchnorm_backward_f32_instance.cpp + device_batchnorm_backward_bf16_instance.cpp + device_batchnorm_backward_f64_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b62c8b99cbdb6cf42922e843c11a0832614b6ea7 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_bf16_blockwise_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_bf16_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_bf16_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_bf16_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d05b8b592c29aceef387a4fdec8c02854501995f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_f16_blockwise_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_f16_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_f16_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_f16_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_f16_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3ef95d12e17acfebc029ac6e1a878194691739f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_f32_blockwise_instances = std::tuple< + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_f32_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_f32_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_f32_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_f32_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41be396c24a28b5a7fbc47b10d7cf6ae5e6e07e9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_f64_blockwise_instances = std::tuple< + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_f64_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_f64_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_f64_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_f64_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd1e05b1133ad2a2515a936ef968667a03bc11be --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_forward_bf16_blockwise_instances = + std::tuple < + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_forward_bf16_multiblock_instances = + std::tuple < + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +void add_device_batchnorm_forward_rank_4_3_bf16_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_forward_bf16_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_forward_bf16_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..073dd583f97e0eed167e4bc4ef952530e82854a2 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_forward_f16_blockwise_instances = + std::tuple < + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_forward_f16_multiblock_instances = + std::tuple < + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +void add_device_batchnorm_forward_rank_4_3_f16_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_forward_f16_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_forward_f16_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be63bd44c66c3204c66990ff6048bd895841c84f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_forward_f32_blockwise_instances = std::tuple< + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_forward_f32_multiblock_instances = + std::tuple < + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +void add_device_batchnorm_forward_rank_4_3_f32_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_forward_f32_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_forward_f32_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe87091e8d66e32af0832c1b013abce0911673b2 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_forward_f64_blockwise_instances = std::tuple< + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_forward_f64_multiblock_instances = + std::tuple < + // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl, + DeviceBatchNormFwdImpl + >; +// clang-format on + +void add_device_batchnorm_forward_rank_4_3_f64_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_forward_f64_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_forward_f64_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ffd6a6a7be2081429508fbbaec0edcf6bb1dbf39 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -0,0 +1,7 @@ +add_instance_library(device_contraction_bilinear_instance + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp +) + diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ebbff88346b9dcb1eb4716b1b36f456b806ef1ca --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..980383f3e71b3f1058c295c9aba55c56f596134a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d4b6e3489a9ccce619b0a76dc097f04d72b0d57 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7caa469f54b0169449250f000df038820e35a32f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ad6605486caa7fb4e60e717db6756f52f453bd8 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -0,0 +1,7 @@ +add_instance_library(device_contraction_scale_instance + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp +) + diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5118d0d033ee733c9b3bd7bc3524824c77aae667 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..655d4f0061ae1c417e610e9059f7e2bbd2048363 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a9d20be18bf5c52044b12f84c58f7824f8ef8142 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a68f5c9718a1e547c62f16888a4fd0a0ed1d8197 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..75a367076191c6b045ea96d282a2d237deba642f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_conv1d_bwd_data_instance + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a5c8384227150dd8717433dce6c9ea2395896bd --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e0f3d6199f8939bb0f1ce7f8efc48f2a92cfd3c6 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30537d9373b8175207b267666574033cae4e1a3c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..190c39b870b3cb33e9080c8b591884f4ec387347 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using NWC = ck::tensor_layout::convolution::NWC; +using KXC = ck::tensor_layout::convolution::KXC; +using NWK = ck::tensor_layout::convolution::NWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 1, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..281453b586ba9b635063a3b51642066450c253a5 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,10 @@ +add_instance_library(device_conv2d_bwd_data_instance + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp + + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e14cd558628772991da9b9d4ebb3dd5d852e3285 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f001b83c1711a7185d71f5b2c7f795005b35b696 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +using AccDataType = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 8, 1>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 8, 1>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83ba6a1c6bb8a3098da1b31a59a1d4a12b8b062e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; +using AccDataType = int32_t; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 1, 8, 4>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Dl< 2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 1, 8, 4>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1da9a81d904b3e43aaf5ace4f612c43c05e66ee0 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +// FIXME: retire dedicated 2D version +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +using device_conv_dedidecate_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); + add_device_operation_instances( + instances, device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances( + instances, + device_conv_dedidecate_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7c33df5e7687ee3e279b350bec9ca17902e84231 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +// FIXME: retire dedicated 2D version +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#if !CK_WORKAROUND_SWDEV_325164 + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#endif + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5f8629f2daef05ca1dbcbec2e8c4c668bf96a84 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +// FIXME: retire dedicated 2D version +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); + add_device_operation_instances( + instances, device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8076d6d35669114d84b922a59948b888dcdab506 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +// FIXME: retire dedicated 2D version +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 2, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +// FIXME: retire dedicated 2D version +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //#####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); + add_device_operation_instances( + instances, device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5b646852fc5c6c71d963000681b335f909a10286 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -0,0 +1,7 @@ +add_instance_library(device_conv2d_fwd_instance + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33503b9f8ae00a92c25b178629736c7e8a5474c3 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5e4bd199ea8322312af832b63a0ea559109131a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f43d13e309333bc86abf22465b7f035c1445321b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ce6b04c4218190ccd3cad238f1bc90ec3002df9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76ab3189d7f210446b305a3daec80a2902de63ff --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..670cd94fc9f988527be6bd49b22740b9de67fb09 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,3 @@ +add_instance_library(device_conv2d_fwd_bias_relu_instance + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8c255088877e95e2cbbbacbfba6f9e1f4fc0291 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +static constexpr auto MemorySet = ck::InMemoryDataOperationEnum::Set; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..68d5f582fdc01deb6984e1b56dfe64d20291aa1f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_conv2d_fwd_bias_relu_add_instance + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp +) + diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe7152471ed54aa5334e820f993232912013c6de --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +// arbitrary conv +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// 1x1, stride 1, pad 0 +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +// Odd C +using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< + // clang-format off + //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); + add_device_operation_instances( + instances, + device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..db92208fd7beb5c3512d35a0b4a751992ed5f69b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_conv3d_bwd_data_instance + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..04ce7c07639c968d90a8053f4b67b1335acc9b69 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NDHWC = ck::tensor_layout::convolution::NDHWC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0251d9157f3b894208b17091ab6c79eec8bf605d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NDHWC = ck::tensor_layout::convolution::NDHWC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2975727e48385207678abdcbc1fda776e761664 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using NDHWC = ck::tensor_layout::convolution::NDHWC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc86d7302450a01263dc07d8f895c669795f3644 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using NDHWC = ck::tensor_layout::convolution::NDHWC; +using KZYXC = ck::tensor_layout::convolution::KZYXC; +using NDHWK = ck::tensor_layout::convolution::NDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //##############################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##############################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##############################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConvNdBwdDataNwcKxcNwk_Xdl< 3, int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..47516b4162029431cf19386d89e45668afbed81f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt @@ -0,0 +1,3 @@ +add_instance_library(device_elementwise_instance + device_normalize_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..baddecf6451dbc58f5ee46a2d5fc5fcf754d609e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using inputType = F16; +using MeanType = F32; +using SquareMeanType = F32; +using GammaDataType = F16; +using BetaDataType = F16; +using outputType = F16; + +using Normalize = ck::tensor_operation::element_wise::Normalize; +using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std::tuple< + // clang-format off + //###################|| | functor| NDim| MPerThread| | | + DeviceElementwise, Tuple, Normalize, 2, 8, Sequence<8, 1, 1, 8, 8>, Sequence<8> >, + DeviceElementwise, Tuple, Normalize, 2, 4, Sequence<4, 1, 1, 4, 4>, Sequence<4> >, + DeviceElementwise, Tuple, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >, + DeviceElementwise, Tuple, Normalize, 2, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> > + // clang-format on + >; + +void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( + std::vector, Tuple, Normalize, 2>>& + instances) +{ + add_device_operation_instances( + instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c7cc2cd31288b2f2aaec0dc49f92165b912c225 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt @@ -0,0 +1,3 @@ +add_instance_library(device_elementwise_normalization_instance + device_elementwise_normalization_f16_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f15372ed91b7adbb4cbec369613f6c4c500d59e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Add = ck::tensor_operation::element_wise::Add; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +template +// clang-format off +using device_elementwise_normalization_f16_instances = + std::tuple < + // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 1024, 1, 1024, 1, 32, 1, 8, 1, 8, 1, 8, 8>, + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 2, 1, 2, 1, 2, 2> + >; +// clang-format on + +void add_device_elementwise_normalization_rank_2_1_f16_instances( + std::vector, F16, F16, F32, F16, Add, Pass, 2, 1>>>& + instances) +{ + add_device_operation_instances( + instances, device_elementwise_normalization_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e20d592c84ed8da549f29160c87f65579307096a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -0,0 +1,43 @@ +add_instance_library(device_gemm_instance + device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp + device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp + device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp + device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp + device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp + device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp + device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp + device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d2f18e14e37be6b1c5528e8f43f23d1b6c88996 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..01e3b3793a487fd5eb6b6b5a6ca6a10e86bd0588 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..804e86a066ffa47c6d96de81cd4bef5644a5b131 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..159fa90f7432248721f3bfef606f6df3b04584e9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d8e7798438dcaaa9d7c2e48e8074a2f58a635438 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0034ac59c3888d8dcc1b2bf8e136ac4f02c2e6ad --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b540b8b3492c355d75adf98cacc1826417db0b9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f6ff5111b1fa5dbf39e4c60d1770a319aa9dbba --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4208245e5756ccf3a2e6be29c3a5c5f73e97acc --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06fab7f68a2efef857107e3ef6038b4505e21712 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6d72fa221f4b946f3a25363a99b737a8a830428 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..67d2e3ce4b41c097fffb3099b97a8d398b44b743 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..03eebf4ec3062cf5f85486288673fb96fb2761b1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | Version| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d8de04cd9fa66931dde4ed0494fc429659506c3 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7b12b7cf1f42da6d8c321c1ac121166753b7ff59 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..730ffd4633de3c72cb6ab932126a18d8798418dd --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..619473ff0ce40a5150c59efdfd62bc905d7ce83f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8e06f9d26b4d3cf8a22259a54e103c7b8ce76190 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f9458b7483c5c29011b4aa063061436da23f9e36 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77a03b746cc12e5f8cea8772c4432c9fb4e24110 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef8d7d4e40ef4de46857db4932e109780e52f044 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb65cc7b68b359ec015af3adcd3a7866983d492b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b1014ed827d08f275b5da4c20abe14f153b9950 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e6f6add8bf45e98f6e9607b48a924e00ef642323 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80b3d03da26c49136c8220f8a433d8798ad751e1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93b3df1e572604e70d059c6cb622f8143eb2a92a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f10365d892455e3e36f5e05ec94b7a4f577c1fbc --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 16, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 16, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7a9eb62cef20c05cee9ff76e5d1c0b21fb95ab9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9fb45b00365896be88e6aff214922c51325a4d65 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..18a78674e7ac4d79e6da3b035c8e9ea966618f16 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cef6070af8a72a6cdb0616574736e722094f3a74 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1be70d6ca414f785ec2582fd90bc5d3d2617e4e3 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b8455ffa93fcceaf9d0e5d5595197f8bafddcbe --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9e28e3d7ef70b2e3aa785b048280fb4959ce312 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b1a5a57bb8b22b9235fdd02fe3f9da659b13c0d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..301d3b55b59c0a6818b27570fe1c9a7b9ff44946 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd16f35ff233fb04c79d50888f54b806e79e611a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39166698473dbd05010f960e82366012ffaa1863 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f64_f64_f64_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a623034ef65a1cfd0d220b9ad2253ab0b298620 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f64_f64_f64_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5ef8d08de90a65ae04f552fc87447e3b7887bc92 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9557bae893e1affc0f2138aaa9cfca206e267d1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bbf81a5fa25d096a87c9ece48f117e38ac70550a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_gemm_add_add_fastgelu_instance + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..463e0865c0ae6980c534005d1ace62750645a74c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b71ff1b9972a97276a0fedf0106814a54ba19a0c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[k, m], b[n, k], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9060c9b1b084512f038ccd5c5b0c3573ac32f0a9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81cf01d6a9d70346e8ea2387220fb03258e6e580 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m ,n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0beb10e37978cf64d51a9463ca9c01c278a51273 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_gemm_add_fastgelu_instance + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4da85cc46eb4ded15bcd63dd7b0d89864637d2b0 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0) +// outout: e[m, n] +// input: a[k, m], b[k, n], d0[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab83e4baabf80242bfbf569fb13b5d39a3686ad7 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[k, m], b[n, k], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4cd3fadbe93f4231388e9efccbe28a2c69a6f1b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..207e76ffe5f18f9a4d64a0083c27829933e7a5fa --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m ,n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ccada3a85ebf1b84efa685ab49d4cc97f1c39e4d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_gemm_bias_add_reduce_instance + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8747af4828a2492f6db8e23b575de4efb7ce90e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed54c3a9bf58346a8829f3d7f275c6668e00ed1d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da7eae637bfb16580b4ec98e479768c2649adad7 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34345095e0c2489f44437ea58f58b2891cc91ff1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..cb1b3a486fd69dbcf928ddae31bc9a82a0465683 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_gemm_bilinear_instance + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55461dfba783de45a2991836535b3b652cc49a44 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n]) +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K Padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..405e69975ca9ffec97999bce38fd73d953e139ed --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[k, m] * b[n, k], d[m, n]) +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K Padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9af31b3a129d2d1bf1ff6f32d0ee8e252cbdfa29 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // no padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..841b7a1d47ce48b8486f0af04ff80571e4601d1a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[m, k] * b[n, k], d[m, n]) +using device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + // no padding + // N % 8 == 0 && K % 8 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // M/N/K padding + // N % 8 == 0 && K % 8 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // M/N/K padding + // N % 4 == 0 && K % 4 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // M/N/K padding + // N % 8 == 0 && K % 1 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1> + + // clang-format on + >; + +void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..17d27ab150f520e5405274171f9433d9e8cc5dda --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_gemm_fastgelu_instance + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f7f643beb6832594bdc393777e5e08663a1f94e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b)) +// outout: e[m, n] +// input: a[k, m], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8e9f35d240da45fb902a2d02cd73d2eab5167e1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b)) +// outout: e[m, n] +// input: a[k, m], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f804d45a560ac3b7b519cb3ab1840214e392d39 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b)) +// outout: e[m, n] +// input: a[m, k], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..60cb138f5654bb6f2b10ea5ba0a03191a49906ab --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b)) +// outout: e[m, n] +// input: a[m, k], b[n, k] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +// irregular tile size +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2b2cf8c774a2e8f3b4cfc1e211cc20475d4100ba --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_gemm_reduce_instance + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..59e2b2da8641b9fb094750077d760e9f611d9c98 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb09bf8b8e82f6c2056881f37b9e65c73657f0e5 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a3b566de64505adb19f75ffe5e126cfed5da438 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b17e47b1cce153654d74ca7497b5c04117071de --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b336227465ff2602e13e94c9889a7eb0f8665cc --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -0,0 +1,10 @@ +add_instance_library(device_gemm_splitk_instance + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e178d3b0adeb563c36d6a8f1f3936d048df23f75 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..52be9fe709bd710f3e05f381e342deef5c7affae --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b5ff404845d5ae00b5664ee44a405e979ce52d2 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7fc35c4198eb49bf2625f6cfa8eb80a001fdc086 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f27b2199e0c9dcb1e91566cb0aa0a05edaef5b0e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9a1095570aaf33aa8f3f93710cbb6d3ae276430 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44e5f597d0d714438c97a9ce7d591736f4fe7f48 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3a9063f7fb847b9c303ea836f97ed4cf3173ab9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3808e0248f4dc12599154880ea20bcdbb453858a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -0,0 +1,5 @@ +add_instance_library(device_grouped_conv1d_bwd_weight_instance + device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp + device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp + device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05ba449246e8c6f68182eb707a3e9da0974e2078 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using GNWC = ck::tensor_layout::convolution::GNWC; +using GKXC = ck::tensor_layout::convolution::GKXC; +using GNWK = ck::tensor_layout::convolution::GNWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_bf16_f32_bf16_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_bf16_f32_bf16_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_bf16_f32_bf16_instances{}); + add_device_operation_instances( + instances, + device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_bf16_f32_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a610a747cc9cc7555bd18b17fd499d19cd7c439 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using GNWC = ck::tensor_layout::convolution::GNWC; +using GKXC = ck::tensor_layout::convolution::GKXC; +using GNWK = ck::tensor_layout::convolution::GNWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f16_default_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f16_default_instances{}); + add_device_operation_instances( + instances, device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..90e074f052c6f935bcbf862250670289e0634617 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using GNWC = ck::tensor_layout::convolution::GNWC; +using GKXC = ck::tensor_layout::convolution::GKXC; +using GNWK = ck::tensor_layout::convolution::GNWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f32_default_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f32_default_instances{}); + add_device_operation_instances( + instances, device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d90593e37708f8c21e2ae1701099fa4cf41da9e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_grouped_conv1d_fwd_instance + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..74aebf1031d5b02db54fece32fa693e0c532287d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNWC = ck::tensor_layout::convolution::GNWC; +using GKXC = ck::tensor_layout::convolution::GKXC; +using GNWK = ck::tensor_layout::convolution::GNWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] +using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..361ea8f4ee9ff8144fd2392754fdaa53f380e40b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNWC = ck::tensor_layout::convolution::GNWC; +using GKXC = ck::tensor_layout::convolution::GKXC; +using GNWK = ck::tensor_layout::convolution::GNWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] +using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3145b716402e4e5cf8e542dc991bc7db0a655e74 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNWC = ck::tensor_layout::convolution::GNWC; +using GKXC = ck::tensor_layout::convolution::GKXC; +using GNWK = ck::tensor_layout::convolution::GNWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] +using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cde93f902c95d1c28f85b7a7627089fe3e97f841 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNWC = ck::tensor_layout::convolution::GNWC; +using GKXC = ck::tensor_layout::convolution::GKXC; +using GNWK = ck::tensor_layout::convolution::GNWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] +using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances = std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3b2968d48efd3e0061c36df1febb290b64264f73 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -0,0 +1,3 @@ +add_instance_library(device_grouped_conv2d_bwd_data_instance + device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d604d42cc3a9fcbf8e9750ab5e3b2064f0bc161 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +using device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple< + // clang-format off + // 1. Default + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // 2. Filter1x1Stride1Pad0 + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4009121e7fb12084e94308e3a1b1070ce9f72b11 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_grouped_conv2d_bwd_weight_instance + device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +) + diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ede21f1f4f7124156540518bf4252d3526929606 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_bf16_f32_bf16_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances{}); + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_bf16_f32_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99e556618c3bfc535807ae8595a33e0f6b6d5390 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f16_default_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f16_default_instances{}); + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15871a28c3a2a85c053fc3259f9de67b5694ede6 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f32_default_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f32_default_instances{}); + add_device_operation_instances( + instances, + device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ef1b6866efce4ade692d86e50a3240fdd6ac0ae --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -0,0 +1,13 @@ +add_instance_library(device_grouped_conv2d_fwd_instance + # GNHWC, GKYXC, GNHWK + device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp + # NHWGC, GKYXC, NHWGK + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + #dl + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc18b3c7358c44b3bdbf00db99274f432beefea4 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using OutDataType = ck::half_t; + +using Empty_Tuple = ck::Tuple<>; +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto Filter1x1Pad0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto Filter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances = std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances = + std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances{}); + + add_device_operation_instances( + instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances{}); + + add_device_operation_instances( + instances, + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..648b39637ebe7a49d42c5dd2620834b76a04cc9b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using InDataType = float; +using WeiDataType = float; +using AccDataType = float; +using OutDataType = float; + +using Empty_Tuple = ck::Tuple<>; +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto Filter1x1Pad0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto Filter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances = std::tuple< + // clang-format off + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances = std::tuple< + // clang-format off + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances = + std::tuple< + // clang-format off + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances{}); + + add_device_operation_instances( + instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances{}); + + add_device_operation_instances( + instances, + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cb5d06999512687cc1ef0e9fee1444096bcf1c2 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using InDataType = int8_t; +using WeiDataType = int8_t; +using AccDataType = int32_t; +using OutDataType = int8_t; + +using Empty_Tuple = ck::Tuple<>; +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto Filter1x1Pad0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto Filter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances = std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances = + std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances{}); + + add_device_operation_instances( + instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances{}); + + add_device_operation_instances( + instances, + device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29f3310314548b332d6fab41f7ddcebb215775eb --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +using device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // OddC + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a4a3d2a45e2a7c11cedab653bf4bed447537115 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // OddC + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1fec35fd9fffa59afe9c7bdc81e3dd7ada2dde8f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..59b012134058cbd64a3548ea3892cd477660d4d9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8aca730433060d045e9c2045e6a366b08e0fe5cf --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // OddC + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..04cad43e75e6909e32cedc925140519708babe42 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -0,0 +1,5 @@ +add_instance_library(device_grouped_conv3d_bwd_weight_instance + device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e48db4a5314528bb45f4444f93da4ba8872e86d8 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_bf16_f32_bf16_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances{}); + add_device_operation_instances( + instances, + device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_bf16_f32_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1655850ec148568734cbdbace8eefe8cb6487e0c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f16_default_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f16_default_instances{}); + add_device_operation_instances( + instances, + device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aba46b7ebeb90afb8881fc7ad4d8a2e1281eb146 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances{}); + add_device_operation_instances( + instances, + device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..78eedca5f7687d7adad6caa0dbb91da84550cd94 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_grouped_conv3d_fwd_instance + device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4ae8b6ce5f2ba3d8f3569449934d84fc631252c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..061674bd829840def99f4dcf317886844d398fea --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed7e5476760983e8ca5a17475585ec5eb9e834f2 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf5fa306013e485e6dd57faa011bee8217001148 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using GNDHWC = ck::tensor_layout::convolution::GNDHWC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances = std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..82beb2ace282742a17a14881b0520677b1324c74 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_grouped_gemm_instance + device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp + device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f2097b07bd4d767eb8d5fd4d6780f445d2db99e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[k, m] * b[k, n] = e[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..677bd1a2c643993c4617bc3e29619d2df5dc7665 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[k, m] * b[n, k] = e[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95a1a87d76337be22ea4c4335eeb2398ba774f3c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a103406d53b5321d66a00474611d7630936855e1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[m, k] * b[n, k] = e[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa0cc1148051c8000b588cc72d0f15f38a412112 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_normalization_instance + device_normalization_f16_instance.cpp + device_normalization_f32_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8994d9dcb6e799509e521cbbd7d77f9ba5984fd2 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +template +// clang-format off +using device_normalization_f16_instances = + std::tuple < + // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // fallback kernel + DeviceNormalizationImpl, // fallback kernel + DeviceNormalizationImpl, // fallback kernel + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + >; +// clang-format on + +void add_device_normalization_rank_2_1_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_normalization_f16_instances{}); +} + +void add_device_normalization_rank_4_3_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_normalization_f16_instances{}); +} + +void add_device_normalization_rank_5_3_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_normalization_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4a7e1fd0b948b4f64e8134893e6e913a06af2800 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Pass = ck::tensor_operation::element_wise::PassThrough; + +template +using device_layernorm_f32_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // fallback kernel + DeviceNormalizationImpl, // fallback kernel + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl + // clang-format on + >; + +void add_device_normalization_rank_2_1_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_layernorm_f32_instances{}); +} + +void add_device_normalization_rank_4_3_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_layernorm_f32_instances{}); +} + +void add_device_normalization_rank_5_3_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_layernorm_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f826afd68074d0e73aeb19ebfdc672aa26fcb04 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_quantization_instance + device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp + device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp + device_conv2d_xdl_perchannel_quantization_int8_instance.cpp + device_conv2d_xdl_perlayer_quantization_int8_instance.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e87e9875939e578dafcdaf973db9909f32894ec7 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_bias_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); +} + +void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06eed760149de52e26e351170d2cd310c4698e26 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_bias_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); +} + +void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6904e269f94af8c53298f99c62f205a1ca1cfcf9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using GK = ck::tensor_layout::convolution::G_K; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; + +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; +using I32_Tuple = ck::Tuple; +using F32_Tuple = ck::Tuple; +using I32_F32_Tuple = ck::Tuple; + +using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; +using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; +using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; +using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr ck::index_t NDimSpatial = 2; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +template +// clang-format off +using device_conv2d_int8_instances = + std::tuple < + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16> + >; +// clang-format on + +// for conv + multiple of 32 bit Ds. bit of Ds will affect the ScalarPerVector of C +template +// clang-format off +using device_conv2d_int8_32Ds_instances = + std::tuple < + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8> + >; +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f1aa0c5c77afac489e571b7b5dad79e75c37efa --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); +} + +void add_device_conv2d_relu_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83435d8119a7beae3bc6964f3ce8514bf7e8000b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); +} + +void add_device_conv2d_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..31ae7226f4739b5cf7890e1e693da5ab4ea6911a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt @@ -0,0 +1,76 @@ +add_instance_library(device_reduce_instance + device_reduce_instance_blockwise_f16_f16_f16_min.cpp + device_reduce_instance_blockwise_f16_f16_f16_max.cpp + device_reduce_instance_blockwise_f16_f16_f16_amax.cpp + device_reduce_instance_blockwise_f16_f32_f16_add.cpp + device_reduce_instance_blockwise_f16_f32_f16_avg.cpp + device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp + device_reduce_instance_blockwise_f32_f32_f32_add.cpp + device_reduce_instance_blockwise_f32_f32_f32_avg.cpp + device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp + device_reduce_instance_blockwise_f32_f32_f32_min.cpp + device_reduce_instance_blockwise_f32_f32_f32_max.cpp + device_reduce_instance_blockwise_f32_f32_f32_amax.cpp + device_reduce_instance_blockwise_f32_f64_f32_add.cpp + device_reduce_instance_blockwise_f32_f64_f32_avg.cpp + device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp + device_reduce_instance_blockwise_f64_f64_f64_add.cpp + device_reduce_instance_blockwise_f64_f64_f64_avg.cpp + device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp + device_reduce_instance_blockwise_f64_f64_f64_min.cpp + device_reduce_instance_blockwise_f64_f64_f64_max.cpp + device_reduce_instance_blockwise_f64_f64_f64_amax.cpp + device_reduce_instance_blockwise_i8_i32_i8_add.cpp + device_reduce_instance_blockwise_i8_i32_i8_avg.cpp + device_reduce_instance_blockwise_i8_i8_i8_min.cpp + device_reduce_instance_blockwise_i8_i8_i8_max.cpp + device_reduce_instance_blockwise_i8_i8_i8_amax.cpp + device_reduce_instance_blockwise_b16_f32_b16_add.cpp + device_reduce_instance_blockwise_b16_f32_b16_avg.cpp + device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp + device_reduce_instance_blockwise_b16_f32_b16_min.cpp + device_reduce_instance_blockwise_b16_f32_b16_max.cpp + device_reduce_instance_blockwise_b16_f32_b16_amax.cpp + device_reduce_instance_threadwise_f16_f16_f16_min.cpp + device_reduce_instance_threadwise_f16_f16_f16_max.cpp + device_reduce_instance_threadwise_f16_f16_f16_amax.cpp + device_reduce_instance_threadwise_f16_f32_f16_add.cpp + device_reduce_instance_threadwise_f16_f32_f16_avg.cpp + device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp + device_reduce_instance_threadwise_f32_f32_f32_add.cpp + device_reduce_instance_threadwise_f32_f32_f32_avg.cpp + device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp + device_reduce_instance_threadwise_f32_f32_f32_min.cpp + device_reduce_instance_threadwise_f32_f32_f32_max.cpp + device_reduce_instance_threadwise_f32_f32_f32_amax.cpp + device_reduce_instance_threadwise_f32_f64_f32_add.cpp + device_reduce_instance_threadwise_f32_f64_f32_avg.cpp + device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp + device_reduce_instance_threadwise_f64_f64_f64_add.cpp + device_reduce_instance_threadwise_f64_f64_f64_avg.cpp + device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp + device_reduce_instance_threadwise_f64_f64_f64_min.cpp + device_reduce_instance_threadwise_f64_f64_f64_max.cpp + device_reduce_instance_threadwise_f64_f64_f64_amax.cpp + device_reduce_instance_threadwise_i8_i32_i8_add.cpp + device_reduce_instance_threadwise_i8_i32_i8_avg.cpp + device_reduce_instance_threadwise_i8_i8_i8_min.cpp + device_reduce_instance_threadwise_i8_i8_i8_max.cpp + device_reduce_instance_threadwise_i8_i8_i8_amax.cpp + device_reduce_instance_threadwise_b16_f32_b16_add.cpp + device_reduce_instance_threadwise_b16_f32_b16_avg.cpp + device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp + device_reduce_instance_threadwise_b16_f32_b16_min.cpp + device_reduce_instance_threadwise_b16_f32_b16_max.cpp + device_reduce_instance_threadwise_b16_f32_b16_amax.cpp + device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp + device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp + device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp + device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp + device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp + device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp + device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp + device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp + device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp + device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1909183a55c72fec95250d29a34f80c0d36b1758 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ec302010219b84d6109bd7eae209d6f9fef1ff0d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89f3e582802f0d030daf734e11fc8811640cafb1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1bdd1927b1fe035e1393fb5ec09856db0cecaa2 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58e9c5622951dd38cf49eecb013a04abfb7f610a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5012c651aa28e7e4ba78c8caddcbd59c602c8c3 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0970cb9d7c2769ed87d96b7320fef43992e4f6c0 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ee179a5117a21324c1a105fbf67f5df45da6696 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e53b403065490e0148c8ce47107dd892d13bd54a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cab5738fbae05b7e2ee68a1e2ca937c94601f2ea --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d2a4fad2a85805b177575477c6ea3c69a93effe --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e08b64f8b37d8b10d799aaeceeb19a7bccd9df4a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89cabf3762355c7c235e36e3e8701acdb7c95c77 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e602c121d0c9d502cd01dd862a39ed695854a22 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..489b4bc452f6017bee956517ebd9b2d2047ab182 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..04e2c5b164fb5a803336c90154f2e938c421e29c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5c0e53604851ab08808e23e75ca57c9887053725 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..899dfcd37c1b39b3d62b1c0a623f4edbab68a4c4 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5624337a47735470d4a78c1280afda20067be8de --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f3067ce291465a71656eed576986877b7496de7 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2648e7d59dba49520a1e87870ebce49b4197f9de --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f67ae2ee7c6c5b74ef6509a467e6864e16d05fa4 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f8e07851dfcfdb8f4aa51d4aa030b65d527d9fb --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69fecf72f510b6a5246842ba15f886c6d9d842b6 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..129a4f0f0e87a4de66d072596c0cb7ff23af3caa --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21babc4aa6315d59705c55de78410e79a9411a97 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b85b3e2b68e769367cd955d086249b2bdd27f79c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24a8293b5dda85cd8876c28373710b4398f5af0b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73e60fa959ee4654958a93a230e584e02b4e3329 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72e649d8971ff220caeca22b63102da0417039c4 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7e053a06564dc90acbc5540e998b8063496c92c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e3abd35b46f9525fdb026a7bec6480555a70c43 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +template void add_device_reduce_instance_blockwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b32456074f09313864fb3f1829e046478fab27d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3298587a42f30db880f0180605e47d50189fc33b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..729d4fd6e19b470fa47f8a33ca5336c20441ca30 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3e36e312ba1607bd0707bac69cc3ab2fa137adc --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7580e7d7dc3b030f6bb5f27b5c68a8b04be1ce8 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e6feb0071fe6b2798da98b231ac2168efcbb067 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..669c4d34ca751ebb7550a4ee447a03f7bc582f00 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..335a5474ce8985f17b5f8d4444b5af1020a150db --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e95e8391a27978f27709aed71f6fb53ac3b324d6 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25498158a2aa715c0cc73688b7ea6ef96bbf60a1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +template void add_device_reduce_instance_multiblock_atomic_add(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7262b8a5ba6342e4bc1160aa3edc1cdf1856ba6e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c526a74f1a9a8630c8fecdd2e5da440d296621e3 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c7252e742d2ba23272a9378337522c6b09eb703 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..618900a7d75c6e2c478abc6cad9859660deab702 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ce747cbc76473a91983728c0838f4abd23237fca --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06f622b9e69a91e0f58a587497e34a56f92d287f --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..708eb58d4041edc70c0b245160637d39dbbc30d3 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8a62fa1496a5a981e4f82661c403893286244cf --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ce2092153cc9516895a1315718f88bac806f6775 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29251a8b9a6baebec2c104ef920e814cea82ecb0 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..734fa9fd3e1494f6447a54badae8d9c87dcce0db --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7a0e2bfe8932a4566640c014c1f3db9b051a8a6 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b97f3008b8b0eff0cee4ff41e73d55e5cd66e85 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53d01e38d60c93ab989bc5d2be940d7af3c19081 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..125d054f3dcf21f668110ae0c792138f335d0420 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb86a2bbe443b318d04889c9b967244926cc6e91 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..49af08390ad7de3dfc94bde69faaa1f0f374d294 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30cc1b13ecab4c224d7e01f734d3079e7517811a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24f8a9ba5cf22f30a620e3406fe923b20f86a65d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a26702f053cc98ab67ddb50b15743bc312f82dd4 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34fe32628fd5b4e83205204c86599c2ef8c07e64 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..74b15eddbacb01fa73d70b3708853c51d9c58600 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65762492f765b103682cfc4f9b30b23580fa5c5e --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e74295a0dc60c893cde4eb6fff9ac7704365598 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6fdea6cc4df1e6fd112382241cfa491ae2bf2cfc --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..317d573dac5352df541f559cfb6979a9be440efb --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29f95ebcc7dfb3dcdd47c998b83b07f4d24a1b77 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa9f47cbc44fe8e59218425089784d92b14e31b9 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..54a9dd1ab7e333326c841d42b6852abe5a82ec35 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ef5717b5e18edf3283d4da4f5851cace4c59005 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.cpp new file mode 100644 index 0000000000000000000000000000000000000000..140a3c197b0606357da23e6a065aa86143060ca0 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.cpp new file mode 100644 index 0000000000000000000000000000000000000000..317b4ad39c0b014b0f04f88920bde86dc2af833b --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +template void add_device_reduce_instance_threadwise(std::vector>&); +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fc13261a6a74c49594240969cdae03ee4f5e5b7c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt @@ -0,0 +1,26 @@ +add_instance_library(device_softmax_instance + device_softmax_i8_i8_instance.cpp + device_softmax_i8_i8_instance_rank3_reduce1.cpp + device_softmax_i8_i8_instance_rank3_reduce2.cpp + device_softmax_i8_i8_instance_rank3_reduce3.cpp + device_softmax_i8_i8_instance_rank4_reduce1.cpp + device_softmax_i8_i8_instance_rank4_reduce2.cpp + device_softmax_i8_i8_instance_rank4_reduce3.cpp + device_softmax_i8_i8_instance_rank4_reduce4.cpp + device_softmax_f16_f16_instance.cpp + device_softmax_f16_f16_instance_rank3_reduce1.cpp + device_softmax_f16_f16_instance_rank3_reduce2.cpp + device_softmax_f16_f16_instance_rank3_reduce3.cpp + device_softmax_f16_f16_instance_rank4_reduce1.cpp + device_softmax_f16_f16_instance_rank4_reduce2.cpp + device_softmax_f16_f16_instance_rank4_reduce3.cpp + device_softmax_f16_f16_instance_rank4_reduce4.cpp + device_softmax_f32_f32_instance.cpp + device_softmax_f32_f32_instance_rank3_reduce1.cpp + device_softmax_f32_f32_instance_rank3_reduce2.cpp + device_softmax_f32_f32_instance_rank3_reduce3.cpp + device_softmax_f32_f32_instance_rank4_reduce1.cpp + device_softmax_f32_f32_instance_rank4_reduce2.cpp + device_softmax_f32_f32_instance_rank4_reduce3.cpp + device_softmax_f32_f32_instance_rank4_reduce4.cpp +) diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..14d2764529c826587b13661e7a7cab11e8d6ea99 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp" + +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f16_f16_rank3_instances( + std::vector>& instances) +{ + add_device_softmax_f16_f16_rank3_reduce1_instances(instances); + add_device_softmax_f16_f16_rank3_reduce2_instances(instances); + add_device_softmax_f16_f16_rank3_reduce3_instances(instances); +} + +void add_device_softmax_f16_f16_rank4_instances( + std::vector>& instances) +{ + add_device_softmax_f16_f16_rank4_reduce1_instances(instances); + add_device_softmax_f16_f16_rank4_reduce2_instances(instances); + add_device_softmax_f16_f16_rank4_reduce3_instances(instances); + add_device_softmax_f16_f16_rank4_reduce4_instances(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa334b997c27af20f05c42335bde28f5f32f819a --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_f16_f16_rank3_reduce1_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c9d37d8483b7f249a0b1c05551c70acdb46db03 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_f16_f16_rank3_reduce2_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5fbdab5055edb0f124bbc6541fcc705a12b71d63 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_f16_f16_rank3_reduce3_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7dd8640b187a2b793d5ad6931ab329ae024ca432 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f16_f16_rank4_reduce1_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b32fe6838f83440e98f524c59af9a383760dfe7d --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f16_f16_rank4_reduce2_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c05048ec567bd6471fca14c41d062eaf8c0c86a5 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f16_f16_rank4_reduce3_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a235708bd426c7270f16246fa058eae45cdd056 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f16_f16_rank4_reduce4_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5bec5e2639d234b3944f679e99e988f20b2c383 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp" + +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_f32_f32_rank3_instances( + std::vector>& instances) +{ + add_device_softmax_f32_f32_rank3_reduce1_instances(instances); + add_device_softmax_f32_f32_rank3_reduce2_instances(instances); + add_device_softmax_f32_f32_rank3_reduce3_instances(instances); +} + +void add_device_softmax_f32_f32_rank4_instances( + std::vector>& instances) +{ + add_device_softmax_f32_f32_rank4_reduce1_instances(instances); + add_device_softmax_f32_f32_rank4_reduce2_instances(instances); + add_device_softmax_f32_f32_rank4_reduce3_instances(instances); + add_device_softmax_f32_f32_rank4_reduce4_instances(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..57d3f184a6635f4febe632054c9e39798ef0cc82 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_f32_f32_rank3_reduce1_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fae3a4dd6662d8920b1b88c8489e9a6f00f28ed0 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_f32_f32_rank3_reduce2_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6fb70e8e2a6aa097f2a76865f690812dffaffdd --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_f32_f32_rank3_reduce3_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33c7b6f35f351f755a50f48f92d391fd47f40163 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f32_f32_rank4_reduce1_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c22aa574b1f984d5a4f8c3276c9fb11b74002680 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f32_f32_rank4_reduce2_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55f3d2bd207a0a160f00415592f0106f41be5cee --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f32_f32_rank4_reduce3_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb0bcf5ee8a2dd6fd5ed4f77c0d8f4be5c5fcacd --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_f32_f32_rank4_reduce4_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..608cfcf8380be29e24d5e4b3aadd7573285a5224 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp" + +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_softmax_i8_i8_rank3_instances( + std::vector>& instances) +{ + add_device_softmax_i8_i8_rank3_reduce1_instances(instances); + add_device_softmax_i8_i8_rank3_reduce2_instances(instances); + add_device_softmax_i8_i8_rank3_reduce3_instances(instances); +} + +void add_device_softmax_i8_i8_rank4_instances( + std::vector>& instances) +{ + add_device_softmax_i8_i8_rank4_reduce1_instances(instances); + add_device_softmax_i8_i8_rank4_reduce2_instances(instances); + add_device_softmax_i8_i8_rank4_reduce3_instances(instances); + add_device_softmax_i8_i8_rank4_reduce4_instances(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15552dbae5d501c506f4fd14b439a80664a95b66 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_i8_i8_rank3_reduce1_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..67674028860b471888bc87a59e1b7ae751c64eec --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_i8_i8_rank3_reduce2_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b33da93c2e1330485a55f8ede6654ab2a03fbb1 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 3; + +void add_device_softmax_i8_i8_rank3_reduce3_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe3b823e889267f69001e15066bbef201afa9813 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_i8_i8_rank4_reduce1_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ecdf87d9fec061094f83c3beba01497dcb8e5b8 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_i8_i8_rank4_reduce2_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3563135204085ff8e44d1378f9c4d9ffd99e7b68 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_i8_i8_rank4_reduce3_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa21a0bf8a863f433665e01e8229bd179948e1db --- /dev/null +++ b/3rdparty/composable_kernel/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr index_t RANK = 4; + +void add_device_softmax_i8_i8_rank4_reduce4_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/3rdparty/composable_kernel/library/src/utility/CMakeLists.txt b/3rdparty/composable_kernel/library/src/utility/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7f6a59eebe16f1103e68a34bd6646a0bec6a681c --- /dev/null +++ b/3rdparty/composable_kernel/library/src/utility/CMakeLists.txt @@ -0,0 +1,28 @@ +## utility +set(UTILITY_SOURCE + device_memory.cpp + host_tensor.cpp + convolution_parameter.cpp +) + +add_library(utility STATIC ${UTILITY_SOURCE}) +add_library(composable_kernel::utility ALIAS utility) + +target_include_directories(utility PUBLIC + "$" + "$" +) + +rocm_install( + TARGETS utility + EXPORT utilityTargets +) + +rocm_install( + EXPORT utilityTargets + FILE composable_kernelutilityTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) + +clang_tidy_check(utility) diff --git a/3rdparty/composable_kernel/library/src/utility/convolution_parameter.cpp b/3rdparty/composable_kernel/library/src/utility/convolution_parameter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8712d20939d5f9f1b7b981645e77420c7a5c607 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/utility/convolution_parameter.cpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host_utility/io.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" + +namespace ck { +namespace utils { +namespace conv { + +ConvParam::ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(n_dim), + G_(group_count), + N_(n_batch), + K_(n_out_channels), + C_(n_in_channels), + filter_spatial_lengths_(filters_len), + input_spatial_lengths_(input_len), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(strides), + conv_filter_dilations_(dilations), + input_left_pads_(left_pads), + input_right_pads_(right_pads) +{ + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } +} + +ConvParam::ConvParam() + : ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}) +{ +} + +std::vector ConvParam::GetOutputSpatialLengths() const +{ + return output_spatial_lengths_; +} + +std::size_t ConvParam::GetFlops() const +{ + // 2 * G * N * K * C * * + return static_cast(2) * G_ * N_ * K_ * C_ * + ck::accumulate_n( + std::begin(output_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()) * + ck::accumulate_n( + std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()); +} + +std::string get_conv_param_parser_helper_msg() +{ + std::string msg; + + msg += "Following arguments (depending on number of spatial dims):\n" + " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n" + " G, N, K, C, \n" + " , (ie Y, X for 2D)\n" + " , (ie Hi, Wi for 2D)\n" + " , (ie Sy, Sx for 2D)\n" + " , (ie Dy, Dx for 2D)\n" + " , (ie LeftPy, LeftPx for 2D)\n" + " , (ie RightPy, RightPx for 2D)\n"; + + return msg; +} + +ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]) +{ + const ck::index_t G = std::stoi(argv[arg_idx++]); + const ck::index_t N = std::stoi(argv[arg_idx++]); + const ck::index_t K = std::stoi(argv[arg_idx++]); + const ck::index_t C = std::stoi(argv[arg_idx++]); + + std::vector filter_spatial_lengths(num_dim_spatial); + std::vector input_spatial_lengths(num_dim_spatial); + std::vector conv_filter_strides(num_dim_spatial); + std::vector conv_filter_dilations(num_dim_spatial); + std::vector input_left_pads(num_dim_spatial); + std::vector input_right_pads(num_dim_spatial); + + for(int i = 0; i < num_dim_spatial; ++i) + { + filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return ck::utils::conv::ConvParam{num_dim_spatial, + G, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; +} +} // namespace conv +} // namespace utils +} // namespace ck + +std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p) +{ + os << "ConvParam {" + << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ << "\nN: " << p.N_ + << "\nK: " << p.K_ << "\nC: " << p.C_ + << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_ + << "\ninput_spatial_lengths: " << p.input_spatial_lengths_ + << "\nconv_filter_strides: " << p.conv_filter_strides_ + << "\nconv_filter_dilations: " << p.conv_filter_dilations_ + << "\ninput_left_pads: " << p.input_left_pads_ + << "\ninput_right_pads: " << p.input_right_pads_ << "}\n"; + + return os; +} diff --git a/3rdparty/composable_kernel/library/src/utility/device_memory.cpp b/3rdparty/composable_kernel/library/src/utility/device_memory.cpp new file mode 100644 index 0000000000000000000000000000000000000000..90f943313b0961bb96d6394e855809c62d050559 --- /dev/null +++ b/3rdparty/composable_kernel/library/src/utility/device_memory.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host_utility/hip_check_error.hpp" + +#include "ck/library/utility/device_memory.hpp" + +DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) +{ + hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + +void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; } + +std::size_t DeviceMem::GetBufferSize() const { return mMemSize; } + +void DeviceMem::ToDevice(const void* p) const +{ + hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); +} + +void DeviceMem::FromDevice(void* p) const +{ + hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); +} + +void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); } + +DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); } diff --git a/3rdparty/composable_kernel/library/src/utility/host_tensor.cpp b/3rdparty/composable_kernel/library/src/utility/host_tensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e34fbc8f345b8ca6dd5b34988a1839c6c59e61bd --- /dev/null +++ b/3rdparty/composable_kernel/library/src/utility/host_tensor.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/library/utility/host_tensor.hpp" + +void HostTensorDescriptor::CalculateStrides() +{ + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); +} + +std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + +std::size_t HostTensorDescriptor::GetElementSize() const +{ + assert(mLens.size() == mStrides.size()); + return std::accumulate( + mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies()); +} + +std::size_t HostTensorDescriptor::GetElementSpaceSize() const +{ + std::size_t space = 1; + for(std::size_t i = 0; i < mLens.size(); ++i) + { + if(mLens[i] == 0) + continue; + + space += (mLens[i] - 1) * mStrides[i]; + } + return space; +} + +const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + +const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + +std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) +{ + os << "dim " << desc.GetNumOfDimension() << ", "; + + os << "lengths {"; + LogRange(os, desc.GetLengths(), ", "); + os << "}, "; + + os << "strides {"; + LogRange(os, desc.GetStrides(), ", "); + os << "}"; + + return os; +} diff --git a/3rdparty/composable_kernel/profiler/CMakeLists.txt b/3rdparty/composable_kernel/profiler/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bdd7125ac1a208b9b696efbf67eff75b768cdfb7 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/CMakeLists.txt @@ -0,0 +1,5 @@ +include_directories(BEFORE + ${CMAKE_CURRENT_LIST_DIR}/include +) + +add_subdirectory(src) diff --git a/3rdparty/composable_kernel/profiler/README.md b/3rdparty/composable_kernel/profiler/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bfd6a3a53be6ad8c84bb35c97d09f6125145591f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/README.md @@ -0,0 +1,48 @@ +## Profile GEMM kernels +```bash +#arg1: tensor operation (gemm=GEMM) +#arg2: data type (0=fp32, 1=fp16) +#arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT) +#arg4: verification (0=no, 1=yes) +#arg5: initialization (0=no init, 1=integer value, 2=decimal value) +#arg6: print matrix value (0=no, 1=yes) +#arg7: run kernel # of times (>1) +#arg8 to 13: M, N, K, StrideA, StrideB, StrideC + +################ op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC +./bin/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +```bash +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +.... +Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s +``` + +## Profile 2d forward convolution kernels +```bash +#arg1: tensor operation (conv=Convolution) +#arg2: data type (0=fp32, 1=fp16) +#arg3: input tensor layout (0=NCHW, 1=NHWC) +#arg4: weight tensor layout (0=KCYX, 1=KYXC) +#arg5: output tensor layout (0=NKHW, 1=NHWK) +#arg6: verification (0=no, 1=yes) +#arg7: initialization (0=no init, 1=integer value, 2=decimal value) +#arg8: print matrix value (0=no, 1=yes) +#arg9: run kernel # of times (>1) +#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx + ################ op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads + ./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +.... +Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s +``` diff --git a/3rdparty/composable_kernel/profiler/include/profiler/data_type_enum.hpp b/3rdparty/composable_kernel/profiler/include/profiler/data_type_enum.hpp new file mode 100644 index 0000000000000000000000000000000000000000..afcd6fea224f3e06f29309349bcc6ad55caa7d4f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/data_type_enum.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +enum struct DataTypeEnum +{ + Half = 0, + Float = 1, + Int32 = 2, + Int8 = 3, + Int8x4 = 4, + BFloat16 = 5, + Double = 6, + Unknown = 100, +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/data_type_enum_helper.hpp b/3rdparty/composable_kernel/profiler/include/profiler/data_type_enum_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d9bd5e1a4008fc5e6c09815c45408d8b07376a4f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/data_type_enum_helper.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma + +#include "ck/utility/data_type.hpp" +#include "profiler/data_type_enum.hpp" + +namespace ck { + +template +struct get_datatype_from_enum; + +template <> +struct get_datatype_from_enum +{ + using type = int8_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = int32_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = half_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = float; +}; + +template <> +struct get_datatype_from_enum +{ + using type = double; +}; + +template +struct get_datatype_enum_from_type; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum value = DataTypeEnum::Int8; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum value = DataTypeEnum::Int32; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum value = DataTypeEnum::Half; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum value = DataTypeEnum::Float; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum value = DataTypeEnum::Double; +}; + +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b16254279ce433d06f71ca350b53841a294a3de6 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int O, + int BatchCount = 1, + int StrideA0 = -1, + int StrideB0 = -1, + int StrideD0 = -1, + int StrideB1 = -1, + int StrideD1 = -1, + int StrideE1 = -1, + int BatchStrideA0 = -1, + int BatchStrideB0 = -1, + int BatchStrideD0 = -1, + int BatchStrideB1 = -1, + int BatchStrideD1 = -1, + int BatchStrideE1 = -1) + +{ + using Row = tensor_layout::gemm::RowMajor; + using Col = tensor_layout::gemm::ColumnMajor; + + using PassThrough = tensor_operation::element_wise::PassThrough; + + using A0ElementOp = PassThrough; + using B0ElementOp = PassThrough; + using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu; + using B1ElementOp = PassThrough; + using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + + using D0DataType = remove_cvref_t>; + + using D0Layout = remove_cvref_t>; + using D1DataType = remove_cvref_t>; + using D1Layout = remove_cvref_t>; + + // for reference + using RefAcc0DataType = float; + using RefAcc1DataType = float; + + bool pass = true; + + const int DefaultStrideA0 = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideD1 = ck::is_same_v ? O : M; + const int DefaultStrideE1 = ck::is_same_v ? O : M; + + StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1; + StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1; + + const int DefaultBatchStrideA0 = (ck::is_same_v ? K : M) * StrideA0; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideD0 = (ck::is_same_v ? N : M) * StrideD0; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideD1 = (ck::is_same_v ? O : M) * StrideD1; + const int DefaultBatchStrideE1 = (ck::is_same_v ? O : M) * StrideE1; + + BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1; + BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + // E_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a0_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor d0_g_m_n( + f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor d1_g_m_o( + f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{})); + Tensor e1_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); + Tensor e1_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); + + // Host verification: Output of Gemm0 is input A of Gemm1 + Tensor c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{})); + + std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl; + std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d0_g_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + break; + default: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); + DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); + DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize()); + DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) * + e1_g_m_o_device_result.mDesc.GetElementSize()); + + a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data()); + + auto a0_element_op = A0ElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto cde0_element_op = CDE0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto cde1_element_op = CDE1ElementOp{}; + + using DeviceOp = + tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD; + + // get device op instances + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + if(do_verification) + { + // Ref Gemm0 + using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm; + + // Ref Gemm1 + using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm; + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // cde0_elementwise + e0_g_m_n.ForEach( + [&](auto&, auto idx) { cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d0_g_m_n(idx)); }); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{}); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // cde1_elementwise + e1_g_m_o_host_result.ForEach([&](auto&, auto idx) { + cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx)); + }); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a0_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + std::array{d0_g_m_n_device_buf.GetDeviceBuffer()}, + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + std::array{d1_g_m_o_device_buf.GetDeviceBuffer()}, + static_cast(e1_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA0, + StrideB0, + std::array{StrideD0}, + StrideB1, + std::array{StrideD1}, + StrideE1, + BatchStrideA0, + BatchStrideB0, + std::array{BatchStrideD0}, + BatchStrideB1, + std::array{BatchStrideD1}, + BatchStrideE1, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = + (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D0DataType) * N + + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + sizeof(D1DataType) * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data()); + + pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result); + + if(do_log) + { + LogRangeAsType( + std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1583c6db21e0959bd155946b2e81363091b4f43f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_gemm_impl(bool do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int O, + int BatchCount = 1, + int StrideA = -1, + int StrideB0 = -1, + int StrideB1 = -1, + int StrideC = -1, + int BatchStrideA = -1, + int BatchStrideB0 = -1, + int BatchStrideB1 = -1, + int BatchStrideC = -1) + +{ + + using Row = tensor_layout::gemm::RowMajor; + using Col = tensor_layout::gemm::ColumnMajor; + using PassThrough = tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using B1ElementOp = PassThrough; + using Acc0ElementOp = PassThrough; + using CElementOp = PassThrough; + using AccDataType = float; + + // Ref Gemm0 + using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm; + + // Ref Gemm + using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm; + + bool pass = true; + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideC = ck::is_same_v ? O : M; + + StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideC = (ck::is_same_v ? O : M) * StrideC; + + BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor c_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + Tensor c_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + // Host verification: Output of Gemm0 is input A of Gemm1 + Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); + DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize()); + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + using DeviceOp = tensor_operation::device::DeviceBatchedGemmGemm; + + // get device op instances + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // early fail when no instances are found + if(op_ptrs.size() == 0) + { + return false; + } + + if(do_verification) + { + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + acc0_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA, + StrideB0, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideB1, + BatchStrideC, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c07d7c0555490559bbf0357a5d2657bcbe5261fd --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_impl.hpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int BatchStrideA, + int BatchStrideB, + int BatchStrideC, + int StrideA, + int StrideB, + int StrideC, + int BatchCount) +{ + bool pass = true; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{})); + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{})); + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_verification) + { + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..45b7b77388b63c8abba43aabc50a9c0dd86eeede --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int BatchCount) +{ + bool pass = true; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride}); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result({BatchCount, M}); + Tensor d1_g_m_host_result({BatchCount, M}); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result({BatchCount, M}); + Tensor d1_g_m_device_result({BatchCount, M}); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&passthrough, &passthrough}; + + if(do_verification) + { + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + using ReduceAccDataType = ReduceDataType; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType d0_val = + ck::type_convert(c_g_m_n_host_result(batch, m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + d0_g_m_host_result(batch, m) = ck::type_convert(reduce0_acc); + d1_g_m_host_result(batch, m) = ck::type_convert(reduce1_acc); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + d0_g_m_device_result.mDesc.GetElementSpaceSize()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + d1_g_m_device_result.mDesc.GetElementSpaceSize()); + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops, + BatchCount); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); + bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result); + bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result); + + pass = pass && (c_error == true); + pass = pass && (d0_error == true); + pass = pass && (d1_error == true); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_host: ", d0_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", d0_g_m_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_host: ", d1_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", d1_g_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f5ec235141a7238157c0d2fd34f1f9c849decfc4 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp @@ -0,0 +1,347 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int O, + int BatchCount = 1, + int StrideA = -1, + int StrideB0 = -1, + int StrideB1 = -1, + int StrideC = -1, + int BatchStrideA = -1, + int BatchStrideB0 = -1, + int BatchStrideB1 = -1, + int BatchStrideC = -1, + float alpha = -1.f) + +{ + + using Row = tensor_layout::gemm::RowMajor; + using Col = tensor_layout::gemm::ColumnMajor; + using PassThrough = tensor_operation::element_wise::PassThrough; + using Scale = tensor_operation::element_wise::Scale; + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = Scale; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + using AccDataType = float; + + // Ref Gemm0: various type in, fp32 out + using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm; + + // Ref Softmax: fp32 in, various type out + using ReferenceSoftmaxInstance = + tensor_operation::host::ReferenceSoftmax; + + // Ref Gemm1: various type in, various type out + using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm; + + bool pass = true; + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideC = ck::is_same_v ? O : M; + + StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideC = (ck::is_same_v ? O : M) * StrideC; + + BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor c_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + Tensor c_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + // Host verification: Output of Gemm0 is input A of Gemm1 + Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; + + std::srand(1); // work around test flakiness + switch(init_method) + { + case 0: break; + case 1: + // Still unsure whether this kind of deterministic floating point accurary issue is expected + // or not. May want to try exact same approach as the GPU kernel in the host reference + // GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then, + // shrink the input value range as it is less likely to produce errors of around ~1e-3. + // a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); + DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize()); + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + + if(alpha < 0) + { + alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim) + } + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemm; + + // get device op instances + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + if(do_verification) + { + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // mask out upper triangle + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(MaskOutUpperTriangle && idx[1] < idx[2]) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA, + StrideB0, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideB1, + BatchStrideC, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91c28f25fc5953460a00d87961b97d1587aad1ce --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int O, + int G0, + int G1, + float alpha = -1.f) + +{ + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Scale = tensor_operation::element_wise::Scale; + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = Scale; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + using AccDataType = float; + using tensor_operation::device::MaskingSpecialization; + + // Ref Gemm0: various type in, fp32 out + using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm; + + // Ref Softmax: fp32 in, various type out + using ReferenceSoftmaxInstance = + tensor_operation::host::ReferenceSoftmax; + + // Ref Gemm1: various type in, various type out + using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm; + + bool pass = true; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + const int BatchCount = G0 * G1; + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + std::srand(1); // work around test flakiness + switch(init_method) + { + case 0: break; + case 1: + // Still unsure whether this kind of deterministic floating point accurary issue is expected + // or not. May want to try exact same approach as the GPU kernel in the host reference + // GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then, + // shrink the input value range as it is less likely to produce errors of around ~1e-3. + // a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + if(alpha < 0) + { + alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim) + } + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + MaskingSpec>; + + // get device op instances + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // mask out upper triangle + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2]) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + {}, // std::array p_acc0_biases; + {}, // std::array p_acc1_biases; + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // std::array, 1>{acc0_biases_gs_ms_ns_lengths}, + {}, // std::array, 1>{acc0_biases_gs_ms_ns_strides}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_lengths}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_strides}, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + pass = pass & ck::utils::check_err(c_gs_ms_os_device_result, + c_gs_ms_os_host_result, + "Error: Incorrect results!", + rtol, + atol); + + if(do_log) + { + LogRangeAsType(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_gs_ms_os_device_result : ", + c_gs_ms_os_device_result.mData, + ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batchnorm_backward_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batchnorm_backward_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..79d8862081fd78b0706110219dc5af5016cd9122 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batchnorm_backward_impl.hpp @@ -0,0 +1,390 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batchnorm_backward_impl(bool do_verification, + int init_method, + bool do_dumpout, + bool time_kernel, + const std::vector inOutLengths, + const std::vector reduceDims, + bool haveSavedMeanInvVar, + double epsilon) +{ + if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim) + { + throw std::runtime_error("Invalid tensor lengths or number of reduce dimensions!"); + }; + + std::vector scaleBiasMeanVarLengths; + + // used for calculating the effective transferred bytes by each operation + size_t total_length; + size_t invariant_length = 1; + + total_length = + std::accumulate(inOutLengths.begin(), inOutLengths.end(), 1, std::multiplies{}); + + if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; })) + throw std::runtime_error("Invalid reduce dimensions!"); + + for(int dim = 0; dim < Rank; dim++) + { + if(std::none_of(reduceDims.begin(), reduceDims.end(), [&](int d) { return dim == d; })) + { + scaleBiasMeanVarLengths.push_back(inOutLengths[dim]); + invariant_length *= inOutLengths[dim]; + }; + } + + // input data of the batchnorm backward algorithm + Tensor x(inOutLengths); + Tensor dy(inOutLengths); + Tensor bnScale(scaleBiasMeanVarLengths); + + Tensor savedMean(scaleBiasMeanVarLengths); + Tensor savedInvVar(scaleBiasMeanVarLengths); + // savedVariance is only used for initializing savedInvVar + Tensor savedVariance(scaleBiasMeanVarLengths); + + // output data of the batchnorm backward algorithm + Tensor dx_ref(inOutLengths); + Tensor dx(inOutLengths); + + Tensor dscale(scaleBiasMeanVarLengths); + Tensor dbias(scaleBiasMeanVarLengths); + + Tensor dscale_ref(scaleBiasMeanVarLengths); + Tensor dbias_ref(scaleBiasMeanVarLengths); + + auto inOutStrides = x.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(haveSavedMeanInvVar) + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.0001f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the savedMean to be values with tiny variation to the mean of the x values + savedMean.GenerateTensorValue(GeneratorTensor_4{x_mean, noise_stddev}, + num_thread); + + // initialize the variance to be values with tiny variation to the variance of the x values + savedVariance.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + + auto it_src = savedVariance.mData.begin(); + auto it_dst = savedInvVar.mData.begin(); + float tmp_epsilon = std::numeric_limits::epsilon(); + + while(it_src != savedVariance.mData.end()) + { + *it_dst = type_convert( + 1.0f / std::sqrtf(type_convert(*it_src) + tmp_epsilon)); + + it_src++; + it_dst++; + }; + } + else + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + dy.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{-0.2f, 0.2f}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_3{-0.5f, 0.5f}, num_thread); + } + }; + + // input data of the batchnorm backward algorithm + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem dy_dev(sizeof(DyDataType) * dy.mDesc.GetElementSpaceSize()); + + DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize()); + + DeviceMem savedMean_dev(sizeof(MeanVarDataType) * savedMean.mDesc.GetElementSpaceSize()); + DeviceMem savedInvVar_dev(sizeof(MeanVarDataType) * savedInvVar.mDesc.GetElementSpaceSize()); + + // output data of the batchnorm backward algorithm + DeviceMem dx_dev(sizeof(DxDataType) * dx.mDesc.GetElementSpaceSize()); + + DeviceMem dscale_dev(sizeof(DscaleDbiasDataType) * dscale.mDesc.GetElementSpaceSize()); + DeviceMem dbias_dev(sizeof(DscaleDbiasDataType) * dbias.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + dy_dev.ToDevice(dy.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + + if(haveSavedMeanInvVar) + { + savedMean_dev.ToDevice(savedMean.mData.data()); + savedInvVar_dev.ToDevice(savedInvVar.mData.data()); + }; + + std::array arrInOutLengths; + std::array arrInOutStrides; + std::array arrScaleBiasMeanVarLengths; + std::array arrScaleBiasMeanVarStrides; + std::array arrReduceDims; + + std::copy(inOutLengths.begin(), inOutLengths.end(), arrInOutLengths.begin()); + std::copy(inOutStrides.begin(), inOutStrides.end(), arrInOutStrides.begin()); + std::copy(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + arrScaleBiasMeanVarLengths.begin()); + std::copy(scaleBiasMeanVarStrides.begin(), + scaleBiasMeanVarStrides.end(), + arrScaleBiasMeanVarStrides.begin()); + + std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin()); + + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + // add device batchnorm-backward instances + using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceBatchNormBwdInstance = + ck::tensor_operation::host::ReferenceBatchNormBwd; + + auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{}; + + auto argument_ptr_ref = batchNormBwd_ref.MakeArgumentPointer( + arrInOutLengths, + arrInOutStrides, + arrInOutStrides, + arrInOutStrides, + arrReduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + x.mData.data(), + dy.mData.data(), + bnScale.mData.data(), + haveSavedMeanInvVar ? savedMean.mData.data() : nullptr, + haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr, + epsilon, + PassThroughOp{}, + dx_ref.mData.data(), + dscale_ref.mData.data(), + dbias_ref.mData.data()); + + if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout << "The runtime parameters not supported by the reference instance, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = batchNormBwd_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + } + + int num_kernel = 0; + bool pass = true; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + arrInOutLengths, + arrInOutStrides, + arrInOutStrides, + arrInOutStrides, + arrReduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + dy_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr, + haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr, + epsilon, + PassThroughOp{}, + dx_dev.GetDeviceBuffer(), + dscale_dev.GetDeviceBuffer(), + dbias_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + num_kernel++; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() + << " skipped due to unsupported argument: " << std::endl; + } + + continue; + }; + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + size_t num_bytes = 0; + + // inputing of x, dy, scale, outputing of dx, dscale, dbias + num_bytes += total_length * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) + + invariant_length * sizeof(DscaleDbiasDataType) * 2; + + // inputting of savedMean, savedInvVariance + if(haveSavedMeanInvVar) + num_bytes += invariant_length * sizeof(MeanVarDataType) * 2; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + using ck::utils::check_err; + bool single_pass = true; + + dx_dev.FromDevice(dx.mData.data()); + dscale_dev.FromDevice(dscale.data()); + dbias_dev.FromDevice(dbias.data()); + + // clang-format off + single_pass = single_pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4); + single_pass = single_pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3); + single_pass = single_pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3); + // clang-format on + + pass = pass && single_pass; + }; + + if(do_dumpout) + { + using ck::host_common::dumpBufferToFile; + + // clang-format off + dumpBufferToFile("dump_x.bin", x.mData.data(), x.mDesc.GetElementSize()); + dumpBufferToFile("dump_dy.bin", dy.mData.data(), dy.mDesc.GetElementSize()); + dumpBufferToFile("dump_dx.bin", dx.mData.data(), dx.mDesc.GetElementSize()); + dumpBufferToFile("dump_dx_ref.bin", dx_ref.mData.data(), dx_ref.mDesc.GetElementSize()); + dumpBufferToFile("dump_dscale.bin", dscale.mData.data(), dscale.mDesc.GetElementSize()); + dumpBufferToFile("dump_dscale_ref.bin", dscale_ref.mData.data(), dscale_ref.mDesc.GetElementSize()); + // clang-format off + }; + } + + if(time_kernel) + { + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_batchnorm_forward_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_batchnorm_forward_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..82fe75bf015beba4d56898f67aa2100b2d3c6ff4 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_batchnorm_forward_impl.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batchnorm_forward_impl(int do_verification, + int init_method, + bool do_dumpout, + bool time_kernel, + const std::vector inOutLengths, + const std::vector reduceDims, + bool updateMovingAverage, + bool saveMeanAndInvVariance, + double averageFactor, + double epsilon) +{ + if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim) + { + throw std::runtime_error("Invalid tensor lengths or number of reduce dimensions!"); + }; + + std::vector scaleBiasMeanVarLengths; + + // used for calculating the effective transferred bytes by each operation + size_t total_length; + size_t invariant_length = 1; + + total_length = + std::accumulate(inOutLengths.begin(), inOutLengths.end(), 1, std::multiplies{}); + + if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; })) + throw std::runtime_error("Invalid reduce dimensions!"); + + for(int dim = 0; dim < Rank; dim++) + { + if(std::none_of(reduceDims.begin(), reduceDims.end(), [&](int d) { return dim == d; })) + { + scaleBiasMeanVarLengths.push_back(inOutLengths[dim]); + invariant_length *= inOutLengths[dim]; + }; + } + + // input data of the batchnorm forward algorithm + Tensor x(inOutLengths); + Tensor bnScale(scaleBiasMeanVarLengths); + Tensor bnBias(scaleBiasMeanVarLengths); + + // output data of the batchnorm forward algorithm + Tensor y_ref(inOutLengths); + Tensor y(inOutLengths); + + Tensor resultSaveMean_ref(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance_ref(scaleBiasMeanVarLengths); + + Tensor resultRunningMean_ref(scaleBiasMeanVarLengths); + Tensor resultRunningVariance_ref(scaleBiasMeanVarLengths); + + auto inOutStrides = x.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(updateMovingAverage) + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.04f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the runningMean to be values with tiny variation to the mean of the x + // values + resultRunningMean_ref.GenerateTensorValue( + GeneratorTensor_4{x_mean, noise_stddev}, num_thread); + + // initialize the runningVariance to be values with tiny variation to the variance of + // the x values + resultRunningVariance_ref.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + } + else + { + if constexpr(ck::is_same_v) + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + else + x.GenerateTensorValue(GeneratorTensor_3{-1.0f, 1.0f}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_1{0}, num_thread); + break; + case 2: + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + bnScale.GenerateTensorValue(GeneratorTensor_3{-1.0f, 1.0f}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_3{-1.0f, 1.0f}, num_thread); + } + }; + + // these buffers are usually provided by the user application + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(XDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize()); + DeviceMem bnBias_dev(sizeof(BiasDataType) * bnBias.mDesc.GetElementSpaceSize()); + + // mean_dev or resultSaveMean_dev + DeviceMem resultSaveMean_dev(sizeof(MeanVarDataType) * + resultSaveMean_ref.mDesc.GetElementSpaceSize()); + // meansquare_dev or resultSaveInvVariance_dev + DeviceMem resultSaveInvVariance_dev(sizeof(MeanVarDataType) * + resultSaveInvVariance_ref.mDesc.GetElementSpaceSize()); + // resultRunningMean_dev + DeviceMem resultRunningMean_dev(sizeof(MeanVarDataType) * + resultRunningMean_ref.mDesc.GetElementSpaceSize()); + // resultRunningVariance_dev + DeviceMem resultRunningVariance_dev(sizeof(MeanVarDataType) * + resultRunningVariance_ref.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + bnBias_dev.ToDevice(bnBias.mData.data()); + + if(updateMovingAverage) + { + resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data()); + resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data()); + }; + + // used for storing the device result for verification when updateMovingAverage is enabled + Tensor resultRunningMean(scaleBiasMeanVarLengths); + Tensor resultRunningVariance(scaleBiasMeanVarLengths); + + // used for storing the device result for verification when saveMeanAndInvVariance is enabled + Tensor resultSaveMean(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance(scaleBiasMeanVarLengths); + + std::array arrInOutLengths; + std::array arrInOutStrides; + std::array arrScaleBiasMeanVarLengths; + std::array arrScaleBiasMeanVarStrides; + std::array arrReduceDims; + + std::copy(inOutLengths.begin(), inOutLengths.end(), arrInOutLengths.begin()); + std::copy(inOutStrides.begin(), inOutStrides.end(), arrInOutStrides.begin()); + std::copy(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + arrScaleBiasMeanVarLengths.begin()); + std::copy(scaleBiasMeanVarStrides.begin(), + scaleBiasMeanVarStrides.end(), + arrScaleBiasMeanVarStrides.begin()); + + std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin()); + + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + // add device batchnorm-forward instances + using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceBatchNormFwdInstance = + ck::tensor_operation::host::ReferenceBatchNormFwd; + + auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{}; + + auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( + arrInOutLengths, + arrInOutStrides, + arrInOutStrides, + arrReduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + x.mData.data(), + bnScale.mData.data(), + bnBias.mData.data(), + epsilon, + PassThroughOp{}, + y_ref.mData.data(), + saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, + updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr); + + if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout << "The runtime parameters not supported by the reference instance, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + } + + int num_kernel = 0; + bool pass = true; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + arrInOutLengths, + arrInOutStrides, + arrInOutStrides, + arrReduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + bnBias_dev.GetDeviceBuffer(), + epsilon, + PassThroughOp{}, + y_dev.GetDeviceBuffer(), + saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, + updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + num_kernel++; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() + << " skipped due to unsupported argument: " << std::endl; + } + + continue; + }; + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + size_t num_bytes = 0; + + // inputing of x, scale, bias, outputing of y + num_bytes += total_length * (sizeof(XDataType) + sizeof(YDataType)) + + invariant_length * (sizeof(ScaleDataType) + sizeof(BiasDataType)); + + // outputing of mean, inv-variance + num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(MeanVarDataType) * 2 : 0; + + // updating of moving mean, variance + num_bytes += updateMovingAverage ? invariant_length * sizeof(MeanVarDataType) * 4 : 0; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + using ck::utils::check_err; + bool single_pass; + + y_dev.FromDevice(y.mData.data()); + + if constexpr(ck::is_same_v) + single_pass = check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2); + else + single_pass = check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3); + + if(updateMovingAverage) + { + resultRunningMean_dev.FromDevice(resultRunningMean.mData.data()); + resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); + + // clang-format off + single_pass = single_pass && check_err(resultRunningMean.mData, resultRunningMean_ref.mData, "average mean results", 1.5e-5, 1.5e-5); + single_pass = single_pass && check_err(resultRunningVariance.mData, resultRunningVariance_ref.mData, "average variance results", 1e-5, 1e-5); + // clang-format on + }; + + if(saveMeanAndInvVariance) + { + resultSaveMean_dev.FromDevice(resultSaveMean.mData.data()); + resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); + + // clang-format off + single_pass = single_pass && check_err(resultSaveMean.mData, resultSaveMean_ref.mData, "mean results", 3e-5, 3e-5); + single_pass = single_pass && check_err(resultSaveInvVariance.mData, resultSaveInvVariance_ref.mData, "inv-variance results", 7e-5, 7e-5); + // clang-format on + }; + + pass = pass && single_pass; + }; + + if(do_dumpout) + { + using ck::host_common::dumpBufferToFile; + + // clang-format off + dumpBufferToFile("dump_x.bin", x.mData.data(), x.mDesc.GetElementSize()); + dumpBufferToFile("dump_y.bin", y.mData.data(), y.mDesc.GetElementSize()); + dumpBufferToFile("dump_y_ref.bin", y_ref.mData.data(), y_ref.mDesc.GetElementSize()); + // clang-format off + + if(saveMeanAndInvVariance) + { + // clang-format off + dumpBufferToFile("dump_mean.bin", resultSaveMean.mData.data(), resultSaveMean.mDesc.GetElementSize()); + dumpBufferToFile("dump_mean_ref.bin", resultSaveMean_ref.mData.data(), resultSaveMean_ref.mDesc.GetElementSize()); + dumpBufferToFile("dump_invvar.bin", resultSaveInvVariance.mData.data(), resultSaveInvVariance.mDesc.GetElementSize()); + dumpBufferToFile("dump_invvar_ref.bin", resultSaveInvVariance_ref.mData.data(), resultSaveInvVariance_ref.mDesc.GetElementSize()); + // clang-format on + }; + }; + } + + if(time_kernel) + { + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_bwd_data_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86d394daf90c79d6f9faf34254901dc8e2bd43e2 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +namespace ck { +namespace profiler { + +template +void show_data_nhwc_layout(Tensor& nhwc) +{ + std::cout << "["; + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) + { + std::cout << "["; + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) + { + std::cout << "["; + for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) + { + std::cout << "["; + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) + { + std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; +} + +template +bool profile_conv_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + Tensor input_host_result(in_g_n_c_wis_desc); + Tensor input_device_result(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + + std::cout << "input: " << input_host_result.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + output.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + output.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(output.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input_host_result, + weight, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + } + + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool pass = true; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.N_, + conv_param.K_, + conv_param.C_, + conv_param.input_spatial_lengths_, + conv_param.filter_spatial_lengths_, + conv_param.output_spatial_lengths_, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // for conv bwd data, some input tensor element are zero, but not written by kernel, + // need to set zero + in_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + in_device_buf.FromDevice(input_device_result.mData.data()); + + pass = pass & ck::utils::check_err(input_device_result, input_host_result); + + if(do_log) + { + std::cout << "in : "; + show_data_nhwc_layout(output); + std::cout << std::endl; + + std::cout << "wei: "; + show_data_nhwc_layout(weight); + std::cout << std::endl; + + std::cout << "out_host : "; + show_data_nhwc_layout(input_host_result); + std::cout << std::endl; + + std::cout << "out_device: "; + show_data_nhwc_layout(input_device_result); + std::cout << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1aebef8bb2b2b34a640797ab9e5985d5cb09fbea --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvFwdBiasReluAddPtr = + DeviceConvFwdBiasActivationAddPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_conv_fwd_bias_relu_add_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + using namespace ck::literals; + + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k({K}); + + // residual: assume same layout as output tensor + Tensor resi_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + std::cout << "resi_n_k_ho_wo: " << resi_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize()); + DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + resi_device_buf.ToDevice(resi_n_k_ho_wo.mData.data()); + + using DeviceConvFwdBiasReluAddPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationAddPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(resi_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + ck::utils::check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2bac144334eefc2d39ff63e17637b50d977921ec --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvFwdBiasReluPtr = + DeviceConvFwdBiasActivationPtr; + +void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_conv_fwd_bias_relu_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + using namespace ck::literals; + + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + // bias: assume contiguous 1d vector + Tensor bias_k({K}); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::cout << "bias_k: " << bias_k.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + bias_device_buf.ToDevice(bias_k.mData.data()); + + using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device:: + DeviceConvFwdBiasActivationPtr; + + // add device operator instances + std::vector op_ptrs; + + if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::instance:: + add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs); + } + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + ck::utils::check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f3ba8f00714449e51f180879f3b38d01221cefe --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_conv_fwd_impl.hpp @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/convolution_forward.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_conv_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + + // run reference op + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + host_output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + // init host output to zero + host_output.SetZero(); + + ref_invoker.Run(ref_argument); + } + + using DeviceOp = ck::tensor_operation::device::DeviceConvFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + bool pass = true; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.N_, + conv_param.K_, + conv_param.C_, + conv_param.input_spatial_lengths_, + conv_param.filter_spatial_lengths_, + conv_param.GetOutputSpatialLengths(), + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, host_output); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_convnd_bwd_data_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_convnd_bwd_data_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1e69ebc8bd173b25311b45305a24d71fcd7324a9 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_convnd_bwd_data_impl.hpp @@ -0,0 +1,486 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using INT8 = int8_t; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvBwdDataNoOpPtr = + DeviceConvBwdDataPtr; +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances( + std::vector&); +void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( + std::vector&); + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector&); + +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector&); +void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector&); +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { +using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr; + +template +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} +template +void get_device_conv_bwd_data_op_ptr( + InDataType, WeiDataType, OutDataType, std::vector&, int) +{ + std::cout << "can not find device conv bwd data" << std::endl; + exit(1); +} +template <> +void get_device_conv_bwd_data_op_ptr( + F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + break; + default: break; + } +} +template <> +void get_device_conv_bwd_data_op_ptr( + INT8, INT8, INT8, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); + break; + default: break; + } +} + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(std::size_t i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + return true; +} +template +void show_data_nhwc_layout(Tensor& nhwc) +{ + std::cout << "["; + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) + { + std::cout << "["; + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) + { + std::cout << "["; + for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) + { + std::cout << "["; + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) + { + std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; +} + +template +bool profile_convnd_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + std::vector input_dims{static_cast(N), static_cast(C)}; + input_dims.insert( + std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); + + std::vector filter_dims{static_cast(K), static_cast(C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths)); + + std::vector output_dims{static_cast(N), static_cast(K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input_host_result( + get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor input_device_result( + get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor weights( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor output( + get_output_host_ensor_descriptor(output_dims, NDimSpatial)); + + std::cout << "input: " << input_host_result.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + output.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + output.GenerateTensorValue(GeneratorTensor_1{1}); + weights.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(output.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + if(do_verification) + { + auto RunReference = [&](auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input_host_result, + weights, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + }; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + RunReference(ref_conv); + } + + // add device Conv instances + std::vector conv_ptrs; + get_device_conv_bwd_data_op_ptr( + InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = + ck::utils::conv::get_btype( + N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + in_device_buf.FromDevice(input_device_result.mData.data()); + + if(!check_out(input_host_result, input_device_result)) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + + success = ck::utils::check_err(input_host_result, input_device_result); + + if(do_log) + { + std::cout << "in : "; + show_data_nhwc_layout(output); + std::cout << std::endl; + + std::cout << "wei: "; + show_data_nhwc_layout(weights); + std::cout << std::endl; + + std::cout << "out_host : "; + show_data_nhwc_layout(input_host_result); + std::cout << std::endl; + + std::cout << "out_device: "; + show_data_nhwc_layout(input_device_result); + std::cout << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + return success; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_convnd_bwd_weight_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_convnd_bwd_weight_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e37c887a96f9c843abd62ee83f026da06f887daf --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_convnd_bwd_weight_impl.hpp @@ -0,0 +1,474 @@ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/conv_util.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using DeviceConvndBwdWeightNoOpPtr = + DeviceConvBwdWeightPtr; + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances( + std::vector&); +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector&); + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances( + std::vector&); +void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector&); + +void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +using DeviceConvndBwdWeightNoOpPtr = + ck::tensor_operation::device::instance::DeviceConvndBwdWeightNoOpPtr; + +template +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 2: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + case 1: { + return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +template +void get_device_conv_bwd_weight_op_ptr( + InDataType, WeiDataType, OutDataType, std::vector&, int) +{ + std::cout << "can not find device conv bwd weight" << std::endl; + exit(1); +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + F32, F32, F32, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + break; + default: break; + } +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + F16, F16, F16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + break; + default: break; + } +} + +template <> +void get_device_conv_bwd_weight_op_ptr( + BF16, BF16, BF16, std::vector& conv_ptrs, int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 1: + ck::tensor_operation::device::instance:: + add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + break; + case 2: + ck::tensor_operation::device::instance:: + add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + break; + case 3: + ck::tensor_operation::device::instance:: + add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + break; + default: break; + } +} + +template +void show_data_nhwc_layout(Tensor& nhwc) +{ + std::cout << "["; + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) + { + std::cout << "["; + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) + { + std::cout << "["; + for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) + { + std::cout << "["; + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) + { + std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; + } + std::cout << "]"; +} + +template +bool profile_convnd_bwd_weight_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t split_k) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + std::vector input_dims{static_cast(N), static_cast(C)}; + input_dims.insert( + std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths)); + + std::vector filter_dims{static_cast(K), static_cast(C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths)); + + std::vector output_dims{static_cast(N), static_cast(K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, NDimSpatial)); + Tensor weights_host_result( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor weights_device_result( + get_filters_host_tensor_descriptor(filter_dims, NDimSpatial)); + Tensor output( + get_output_host_ensor_descriptor(output_dims, NDimSpatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights_host_result.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + output.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_1{1}); + output.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + out_device_buf.ToDevice(output.mData.data()); + + // reset input to zero + wei_device_buf.SetZero(); + + if(do_verification) + { + auto RunReference = [&](auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(input, + weights_host_result, + output, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + }; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight(); + RunReference(ref_conv); + } + + // add device Conv instances + std::vector conv_ptrs; + get_device_conv_bwd_weight_op_ptr( + InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial); + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + // using atomic, so need to reset input, setzero is done in invoker + // if(split_k > 1) + //{ + // wei_device_buf.SetZero(); + //} + + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + + if(!conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + continue; + } + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + std::string conv_name = conv_ptr->GetTypeString(); + float ave_time = 0; + + if(std::is_same::value && split_k > 1) + { + // alloc work space + size_t bwd_weight_workspace_size = conv_ptr->GetWorkSpaceSize(argument_ptr.get()); + if(bwd_weight_workspace_size <= 0) + { + printf("wrong work space size\n"); + exit(1); + } + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + conv_ptr->SetWorkSpacePointer(argument_ptr.get(), + wei_work_space_device_buf.GetDeviceBuffer()); + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + } + else + { + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + } + + std::size_t flop = + ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + wei_device_buf.FromDevice(weights_device_result.mData.data()); + + success = ck::utils::check_err(weights_host_result, weights_device_result); + + if(success == false) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + + if(do_log) + { + std::cout << "in : "; + show_data_nhwc_layout(output); + std::cout << std::endl; + + std::cout << "wei: "; + show_data_nhwc_layout(weights_host_result); + std::cout << std::endl; + + std::cout << "out : "; + show_data_nhwc_layout(input); + std::cout << std::endl; + + std::cout << "wei_device: "; + show_data_nhwc_layout(weights_device_result); + std::cout << std::endl; + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; + return success; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7707e16b089ef3058dea988188f80e453fcd4e7d --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" + +#include "ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" + +namespace ck { +namespace profiler { + +template +void host_elementwise2D(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t m = 0; m < shape[0]; ++m) + for(std::size_t n = 0; n < shape[1]; ++n) + { + auto a_val = A(m, n); + auto b_val = B(m, n); + ctype c_val = 0; + functor(c_val, a_val, b_val); + C(m, n) = c_val; + } +} + +template +bool profile_elementwise_layernorm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + using Add = ck::tensor_operation::element_wise::Add; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + if(length.size() != 2) + return false; + + index_t M = length[0]; + index_t N = length[1]; + index_t Stride = N; + + constexpr int Rank = 2; + constexpr int NumReduceDim = 1; + + std::vector reduce_dim = {1}; + std::vector gammaBetaLength = {N}; + std::vector gammaBetaStride = {0, 1}; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + using namespace ck::literals; + + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + }; + + Tensor a(length); + Tensor b(length); + Tensor gamma(gammaBetaLength); + Tensor beta(gammaBetaLength); + Tensor y(length); + Tensor host_y(length); + + switch(init_method) + { + case 0: + a.GenerateTensorValue(GeneratorTensor_1{}); + b.GenerateTensorValue(GeneratorTensor_1{}); + gamma.GenerateTensorValue(GeneratorTensor_1{}); + beta.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + beta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a.GenerateTensorValue(GeneratorTensor_3{0, 1}); + b.GenerateTensorValue(GeneratorTensor_3{0, 1}); + gamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + beta.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_dev(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_dev(sizeof(ADataType) * b.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + + a_dev.ToDevice(a.mData.data()); + b_dev.ToDevice(b.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + std::array input = {a_dev.GetDeviceBuffer(), b_dev.GetDeviceBuffer()}; + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DeviceElementwiseNormalization< + ck::Tuple, + GammaDataType, + BetaDataType, + AccDataType, + YDataType, + Add, + PassThrough, + 2, + 1>; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using XDataType = ADataType; + std::vector mn = {static_cast(M), + static_cast(N)}; + Tensor x(f_host_tensor_descriptor2d(M, N, Stride)); + host_elementwise2D, Tensor, Tensor, Add>( + x, a, b, mn, Add{}); + + using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + length, + { + std::vector{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}, + std::vector{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}, + }, + gammaBetaStride, + gammaBetaStride, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + reduce_dim, + 1e-4, + input, + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + Add{}, + PassThrough{}); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + continue; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = a.mDesc.GetElementSize() * sizeof(ADataType) + + b.mDesc.GetElementSize() * sizeof(BDataType) + + gamma.mDesc.GetElementSize() * sizeof(GammaDataType) + + beta.mDesc.GetElementSize() * sizeof(BetaDataType) + + y.mDesc.GetElementSize() * sizeof(YDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + y_dev.FromDevice(y.mData.data()); + + bool pass = + ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b : ", b.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_y : ", host_y.mData, ",") << std::endl; + LogRangeAsType(std::cout << "y : ", y.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + std::cout << "num_kernel = " << num_kernel << ", best perf = " << best_avg_time << " ms, " + << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is tested" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3cc2ea3b92624da3db43696c181248e55a5f0256 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_add_add_fastgelu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddAddFastGelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d53a6589e0f505ed3beb6a7792c9045c49ee0c68 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_add_fastgelu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddFastGelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b4ec78cdf37659b6e838b10a8c235fdd6a0c9c64 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmBiasAddReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>; + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_add_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideD0) +{ + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor d0_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduce0_m_host_result({M}); + Tensor reduce1_m_host_result({M}); + + Tensor c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduce0_m_device_result({M}); + Tensor reduce1_m_device_result({M}); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = PassThrough; + using D0ElementOp = PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + auto d0_element_op = D0ElementOp{}; + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + using ReduceAccDataType = ReduceDataType; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + ReduceAccDataType acc = static_cast(c_m_n_host_result(m, n)) + + static_cast(bias_n(n)); + + ReduceAccDataType d0 = static_cast(d0_m_n(m, n)); + c_element_op(acc, acc); + d0_element_op(d0, d0); + acc += d0; + c_m_n_host_result(m, n) = static_cast(acc); + } + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType d0_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert(reduce1_acc); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpaceSize()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpaceSize()); + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + gemm_element_ops, + {&d0_element_op}, + reduce_in_element_ops, + reduce_out_element_ops); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + + sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + + sizeof(ReduceDataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); + ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_host: ", reduce0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", reduce0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_host: ", reduce1_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", reduce1_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_bilinear_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_bilinear_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31bae281c45b2e4e9bbdd8068fd3b084270c45ad --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_bilinear_impl.hpp @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_bilinear_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD, + int StrideE, + float alpha, + float beta) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = Bilinear; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{alpha, beta}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9a544c044f4165fbee8cab9ba10a10616f27e22 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_fastgelu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using FastGelu = ck::tensor_operation::element_wise::FastGelu; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = FastGelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::FastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9b164104b505d3d3912206b5fa1a6ee6aeb2ef48 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_impl.hpp @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +int profile_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemm; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference op + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_avg_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass ? 0 : 1; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_reduce_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_reduce_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..370121a3ccff382539ee34994a05b4b035db5512 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_reduce_impl.hpp @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F16 = ck::half_t; +using ReducePtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using DeviceGemmReduceNoOpPtr = + ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor reduce0_m_host_result({M}); + Tensor reduce1_m_host_result({M}); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor reduce0_m_device_result({M}); + Tensor reduce1_m_device_result({M}); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl; + std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using ReduceOp0 = ck::reduce::Add; + using ReduceOp1 = ck::reduce::Add; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + std::array gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; + + const auto reduce0_op = ReduceOp0{}; + const auto reduce1_op = ReduceOp1{}; + + auto passthrough = UnaryIdenticElementOp{}; + auto square = UnarySquareElementOp{}; + auto div = UnaryDivElementOp{N}; + std::array reduce_in_element_ops = {&passthrough, &square}; + std::array reduce_out_element_ops = {&div, &div}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + using ReduceAccDataType = ReduceDataType; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType d0_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d1_val; + + square(d1_val, d0_val); + reduce0_op(reduce0_acc, d0_val); + reduce1_op(reduce1_acc, d1_val); + } + + div(reduce0_acc, reduce0_acc); + div(reduce1_acc, reduce1_acc); + reduce0_m_host_result(m) = ck::type_convert(reduce0_acc); + reduce1_m_host_result(m) = ck::type_convert(reduce1_acc); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * + reduce0_m_device_result.mDesc.GetElementSpaceSize()); + DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * + reduce1_m_device_result.mDesc.GetElementSpaceSize()); + + std::array p_reduces = {reduce0_device_buf.GetDeviceBuffer(), + reduce1_device_buf.GetDeviceBuffer()}; + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + nullptr, + {}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {}, + gemm_element_ops, + {}, + reduce_in_element_ops, + reduce_out_element_ops); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + reduce0_device_buf.SetZero(); + reduce1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); + reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); + ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_host: ", reduce0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", reduce0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_host: ", reduce1_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", reduce1_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_splitk_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e5d5f8765e1b25e67490da7444f31ac38c7aca27 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_splitk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4f9aa98376c9e56058f04938f8171ac93c230e3e --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_bwd_weight_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + ck::index_t split_k) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight_host_result(wei_g_k_c_xs_desc); + Tensor weight_device_result(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight_host_result.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + output.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + output.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + weight_device_result.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + out_device_buf.ToDevice(output.mData.data()); + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight_host_result, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + bool all_pass = true; + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; + + range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); + range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); + range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); + range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); + range_copy(conv_param.input_left_pads_, begin(input_left_pads)); + range_copy(conv_param.input_right_pads_, begin(input_right_pads)); + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.C_, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // using atomic add, so need to reset input + wei_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + wei_device_buf.FromDevice(weight_device_result.mData.data()); + + bool pass = ck::utils::check_err(weight_device_result, weight_host_result); + + if(!pass) + { + std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl; + } + + all_pass &= pass; + + if(do_log) + { + LogRangeAsType(std::cout << "output : ", output.mData, ",") << std::endl; + ; + LogRangeAsType( + std::cout << "weight (device): ", weight_device_result.mData, ",") + << std::endl; + ; + LogRangeAsType( + std::cout << "weight (host): ", weight_host_result.mData, ",") + << std::endl; + ; + LogRangeAsType(std::cout << "input: ", input.mData, ",") << std::endl; + ; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return all_pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b201a2ed331cbf4a0e3951d8d69027f018af09f6 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + + // run reference op + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + host_output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + // init host output to zero + host_output.SetZero(); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + bool pass = true; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, host_output); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "xdl found " << op_ptrs.size() << " instances" << std::endl; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_gemm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..04f94a0f244bf79f39f10470b09230a9f8a5fc3a --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs) +{ + + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0{}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + // if(do_verification) + // { + + // } + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); + + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); + + gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + auto p_ds = std::vector>{}; + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_c, + gemm_descs, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + pass = pass && ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + } + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} // namespace profiler + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_groupnorm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_groupnorm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..81fec5590a8463ed68e1cf9797800f99e3fbb3c3 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_groupnorm_impl.hpp @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" + +#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_groupnorm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + if(length.size() != 5) + return false; + + index_t G = length[3]; + index_t C = length[4]; + + std::vector reduce_dim = {1, 2, 4}; + std::vector gammaBetaLength = {G, C}; + std::vector gammaBetaStride = {0, 0, 0, C, 1}; + + Tensor x(length); + Tensor gamma(gammaBetaLength); + Tensor beta(gammaBetaLength); + Tensor y(length); + Tensor host_y(length); + + switch(init_method) + { + case 0: + x.GenerateTensorValue(GeneratorTensor_1{}); + gamma.GenerateTensorValue(GeneratorTensor_1{}); + beta.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + beta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + gamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + beta.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DeviceNormalization; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, 1e-6); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + length, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + gammaBetaStride, + gammaBetaStride, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + reduce_dim, + 1e-6, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + nullptr, + nullptr, + PassThrough{}); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + continue; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = x.mDesc.GetElementSize() * sizeof(XDataType) + + gamma.mDesc.GetElementSize() * sizeof(GammaDataType) + + beta.mDesc.GetElementSize() * sizeof(BetaDataType) + + y.mDesc.GetElementSize() * sizeof(YDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + y_dev.FromDevice(y.mData.data()); + + bool pass = ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "x : ", x.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_y : ", host_y.mData, ",") << std::endl; + LogRangeAsType(std::cout << "y : ", y.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + std::cout << "num_kernel = " << num_kernel << ", best perf = " << best_avg_time << " ms, " + << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_layernorm_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_layernorm_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eb21d4a586090937bf35491eda492348331503a4 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_layernorm_impl.hpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_layernorm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + if(length.size() < 2) + return false; + + // Assume normalize dimension except for batch (first) dimension + std::vector reduce_length{length.begin() + 1, length.end()}; + std::vector reduce_dim; + for(int i = 1; i < Rank; ++i) + reduce_dim.push_back(i); + + Tensor x(length); + Tensor gamma(reduce_length); + Tensor beta(reduce_length); + Tensor y(length); + Tensor host_y(length); + + std::vector strideXY = + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; + std::vector strideGammaBeta = strideXY; + strideGammaBeta[0] = 0; + + switch(init_method) + { + case 0: + x.GenerateTensorValue(GeneratorTensor_1{}); + gamma.GenerateTensorValue(GeneratorTensor_1{}); + beta.GenerateTensorValue(GeneratorTensor_1{}); + y.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + beta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + y.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + gamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + beta.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + y.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + constexpr int NumReduceDim = Rank - 1; + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DeviceNormalization; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, reduce_dim, 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideXY, + strideGammaBeta, + strideGammaBeta, + strideXY, + reduce_dim, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), + nullptr, + nullptr, + PassThrough{}); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = x.mDesc.GetElementSize() * sizeof(XDataType) + + gamma.mDesc.GetElementSize() * sizeof(GammaDataType) + + beta.mDesc.GetElementSize() * sizeof(BetaDataType) + + y.mDesc.GetElementSize() * sizeof(YDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + y_dev.FromDevice(y.mData.data()); + + bool pass = ck::utils::check_err( + y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "x : ", x.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_y : ", host_y.mData, ",") << std::endl; + LogRangeAsType(std::cout << "y : ", y.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "stride = ", strideXY, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_reduce_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_reduce_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ccb99398f22fc9db702a6801a0858d9456c6b855 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_reduce_impl.hpp @@ -0,0 +1,520 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" + +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_reduction.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct ReduceDescription +{ + static constexpr index_t Rank_ = Rank; + static constexpr index_t NumReduceDim_ = NumReduceDim; + static constexpr ReduceTensorOp ReduceOpId_ = ReduceOpId; + static constexpr bool PropagateNan_ = PropagateNan; + static constexpr bool UseIndex_ = UseIndex; +}; + +using reduce_description_instances = + std::tuple, // for ADD + ReduceDescription<4, 4, ReduceTensorOp::ADD, false, false>, + ReduceDescription<4, 1, ReduceTensorOp::ADD, false, false>, + ReduceDescription<2, 1, ReduceTensorOp::ADD, false, false>, + + ReduceDescription<4, 3, ReduceTensorOp::AVG, false, false>, // for AVG + ReduceDescription<4, 4, ReduceTensorOp::AVG, false, false>, + ReduceDescription<4, 1, ReduceTensorOp::AVG, false, false>, + ReduceDescription<2, 1, ReduceTensorOp::AVG, false, false>, + + ReduceDescription<4, 3, ReduceTensorOp::NORM2, false, false>, // for NORM2 + ReduceDescription<4, 4, ReduceTensorOp::NORM2, false, false>, + ReduceDescription<4, 1, ReduceTensorOp::NORM2, false, false>, + ReduceDescription<2, 1, ReduceTensorOp::NORM2, false, false>, + + ReduceDescription<4, 3, ReduceTensorOp::MIN, false, false>, // for MIN + ReduceDescription<4, 4, ReduceTensorOp::MIN, false, false>, + ReduceDescription<4, 1, ReduceTensorOp::MIN, false, false>, + ReduceDescription<2, 1, ReduceTensorOp::MIN, false, false>, + ReduceDescription<4, 3, ReduceTensorOp::MAX, false, false>, // for MAX + ReduceDescription<4, 4, ReduceTensorOp::MAX, false, false>, + ReduceDescription<4, 1, ReduceTensorOp::MAX, false, false>, + ReduceDescription<2, 1, ReduceTensorOp::MAX, false, false>, + ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, false>, // for AMAX + ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, false>, + ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, false>, + ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, false>, + + ReduceDescription<4, 3, ReduceTensorOp::MIN, false, true>, // for MIN + ReduceDescription<4, 4, ReduceTensorOp::MIN, false, true>, + ReduceDescription<4, 1, ReduceTensorOp::MIN, false, true>, + ReduceDescription<2, 1, ReduceTensorOp::MIN, false, true>, + ReduceDescription<4, 3, ReduceTensorOp::MAX, false, true>, // for MAX + ReduceDescription<4, 4, ReduceTensorOp::MAX, false, true>, + ReduceDescription<4, 1, ReduceTensorOp::MAX, false, true>, + ReduceDescription<2, 1, ReduceTensorOp::MAX, false, true>, + ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, true>, // for AMAX + ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, true>, + ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, true>, + ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, true>>; + +template +bool description_match(const DescriptionType& description, + int Rank, + const std::vector& reduceDims, + ReduceTensorOp ReduceOpId, + bool PropagateNan, + bool UseIndex) +{ + if(description.Rank_ != Rank || description.ReduceOpId_ != ReduceOpId || + description.PropagateNan_ != PropagateNan || description.UseIndex_ != UseIndex) + return (false); + + if(DescriptionType::NumReduceDim_ != reduceDims.size()) + return (false); + + bool result = true; + + return (result); +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +static inline std::array +get_invariant_dims(const std::array& reduceDims) +{ + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + std::array invariantDims; + + // collect invariant dimensions + int dim = 0; + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + invariantDims[dim] = i; + dim++; + }; + + return invariantDims; +}; + +template +bool profile_reduce_impl_impl(bool do_verification, + int init_method, + bool do_dumpout, + bool time_kernel, + const std::vector& inLengths, + const std::array& reduceDims, + float alpha, + float beta) +{ + using namespace ck::tensor_operation::device; + using namespace ck::tensor_operation::device::instance; + using ck::host_common::dumpBufferToFile; + + constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim; + + constexpr bool op_support_indices = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + + constexpr bool OutputIndex = (op_support_indices && UseIndex); + + constexpr bool out_support_atomic_add = std::is_same::value; + constexpr bool op_support_atomic_add = + !op_support_indices && ReduceOpId != ReduceTensorOp::NORM2; + constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add); + + // 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations + // 2) If InDataType is half_t, must use float as AccDataType for non-indexable reduction + // operations + constexpr bool invalid_reduce_1 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // 1) If InDataType is float, must use float as AccDataType for indexable reduction operations + constexpr bool invalid_reduce_2 = + std::is_same::value && + (op_support_indices && !std::is_same::value); + + // 1) The indices can only be used when the reduction operation is indexable + constexpr bool invalid_reduce_3 = (!op_support_indices && UseIndex); + + // 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations + // 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction + // operations + constexpr bool invalid_reduce_4 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // 1) If InDataType is int8_t, the supported operation must be either indexable operations or + // ADD/AVG + constexpr bool invalid_reduce_5 = std::is_same::value && + (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD && + ReduceOpId != ReduceTensorOp::AVG); + + // 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations + constexpr bool invalid_reduce_6 = + std::is_same::value && !std::is_same::value; + + constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || + invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6); + + bool pass = true; + + if constexpr(!invalid_reduce) + { + Tensor in(inLengths); + + std::vector outLengths; + + const auto invariantDims = get_invariant_dims(reduceDims); + + if(reduceDims.size() == Rank) + outLengths.push_back(1); + else + for(auto dim : invariantDims) + outLengths.push_back(inLengths[dim]); + + Tensor out_ref(outLengths); + Tensor out(outLengths); + Tensor out_indices_ref(outLengths); + Tensor out_indices(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verification) + { + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int) : 0; + + DeviceMem out_indices_dev(indicesSizeInBytes); + + float best_avg_time = 0; + float best_gb_per_sec = 0; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + using ReduceOperation = typename reduce_binary_operator::opType; + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); + + using DeviceReduceInstPtr = + DeviceReducePtr; + + std::vector reduce_ptrs; + + add_device_reduce_instance_threadwise(reduce_ptrs); + + add_device_reduce_instance_blockwise(reduce_ptrs); + + if constexpr(use_atomic_add) + { + add_device_reduce_instance_multiblock_atomic_add(reduce_ptrs); + } + + if(reduce_ptrs.empty()) + { + throw std::runtime_error("Wrong! No device REDUCE instance found"); + }; + + if(do_verification) + { + ReductionHost + hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); + }; + + std::array arrInLengths; + std::array arrInStrides; + std::array arrOutLengths; + std::array arrOutStrides; + + ck::ranges::copy(inLengths, arrInLengths.begin()); + ck::ranges::copy(inStrides, arrInStrides.begin()); + ck::ranges::copy(outLengths, arrOutLengths.begin()); + ck::ranges::copy(outStrides, arrOutStrides.begin()); + + for(auto& reduce_ptr : reduce_ptrs) + { + auto argument_ptr = reduce_ptr->MakeArgumentPointer(arrInLengths, + arrInStrides, + arrOutLengths, + arrOutStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_indices_dev.GetDeviceBuffer(), + in_elementwise_op, + acc_elementwise_op); + + if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) + continue; + + std::string reduce_name = reduce_ptr->GetTypeString(); + + auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = + invariant_total_length * reduce_total_length * sizeof(InDataType) + + invariant_total_length * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " + << reduce_name << std::endl; + + if(gb_per_sec > best_gb_per_sec) + { + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + bool single_pass; + + out_dev.FromDevice(out.mData.data()); + single_pass = ck::utils::check_err(out, out_ref); + + if(OutputIndex) + { + out_indices_dev.FromDevice(out_indices.mData.data()); + single_pass = single_pass && ck::utils::check_err(out_indices, out_indices_ref); + }; + + if(!single_pass) + { + std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl; + } + + pass = pass && single_pass; + }; + + if(do_dumpout) + { + dumpBufferToFile("dump_in.bin", in.mData.data(), in.mDesc.GetElementSize()); + dumpBufferToFile("dump_out.bin", out.mData.data(), out.mDesc.GetElementSize()); + dumpBufferToFile( + "dump_out_host.bin", out_ref.mData.data(), out_ref.mDesc.GetElementSize()); + if(OutputIndex) + { + dumpBufferToFile("dump_indices.bin", + out_indices.mData.data(), + out_indices.mDesc.GetElementSize()); + dumpBufferToFile("dump_indices_host.bin", + out_indices_ref.mData.data(), + out_indices_ref.mDesc.GetElementSize()); + }; + }; + }; + + if(time_kernel) + std::cout << "Best Perf: " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s" + << std::endl; + } + else + { + std::cout << "The requested reduction operation is not supported, please check !!!" + << std::endl; + }; + + return pass; +}; + +template +bool profile_reduce_impl(bool do_verification, + int init_method, + bool do_dumpout, + bool time_kernel, + const std::vector& inLengths, + const std::vector& reduceDims, + ReduceTensorOp ReduceOpId, + bool PropagateNan, + bool UseIndex, + float alpha, + float beta) +{ + bool matched = false; + bool pass = true; + + using tuple_of_description_instances = + tensor_operation::device::instance::reduce_description_instances; + + const auto tuple_object = tuple_of_description_instances{}; + + static_for<0, std::tuple_size::value, 1>{}([&](auto i) { + if(matched) + return; + + using descType = remove_cvref_t(tuple_object))>; + + if(!description_match( + descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex)) + return; + + std::array arrReduceDims; + + ck::ranges::copy(reduceDims, arrReduceDims.begin()); + + pass = pass && profile_reduce_impl_impl(descType::ReduceOpId_), + descType::PropagateNan_, + descType::UseIndex_>(do_verification, + init_method, + do_dumpout, + time_kernel, + inLengths, + arrReduceDims, + alpha, + beta); + + matched = true; + }); + + return pass; +}; + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/include/profiler/profile_softmax_impl.hpp b/3rdparty/composable_kernel/profiler/include/profiler/profile_softmax_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..090cdaaa9a2a8b68701ea3e9adb5d8c678a5355d --- /dev/null +++ b/3rdparty/composable_kernel/profiler/include/profiler/profile_softmax_impl.hpp @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace profiler { + +enum struct SoftmaxDataType +{ + F32_F32, // in, out + F16_F16, + BF16_BF16, + INT8_INT8, +}; + +// clang-format off +template std::string type_to_string(); +template <> std::string type_to_string() { return "f32"; } +template <> std::string type_to_string() { return "f16"; } +template <> std::string type_to_string() { return "bf16"; } +template <> std::string type_to_string() { return "int8"; } +template <> std::string type_to_string() { return "int32"; } +// clang-format on + +template +bool profile_softmax_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, + std::vector in_strides, + std::vector reduce_dims, + AccDataType alpha, + AccDataType beta) +{ + if(Rank != in_length.size()) + { + throw std::runtime_error("Input tensor rank is different from template argument Rank!"); + } + + Tensor in = in_strides.empty() ? Tensor(in_length) + : Tensor(in_length, in_strides); + Tensor out(in.mDesc); + Tensor prior_out(in.mDesc); + + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(in.begin(), in.end()); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(prior_out.begin(), + prior_out.end()); + break; + default: + ck::utils::FillUniformDistribution{0.0f, 1.0f}(in); + ck::utils::FillUniformDistribution{-0.5f, 0.5f}(prior_out); + } + + Tensor out_ref(prior_out); + + if(do_verification) + { + using ReferenceSoftmax = + tensor_operation::host::ReferenceSoftmax; + ReferenceSoftmax{}.MakeInvoker().Run({in, out_ref, alpha, beta, reduce_dims}); + } + + DeviceMem in_dev(in.GetElementSpaceSizeInBytes()); + DeviceMem out_dev(out.GetElementSpaceSizeInBytes()); + in_dev.ToDevice(in.data()); + + std::vector in_tensor_lengths(in.GetLengths().begin(), in.GetLengths().end()); + std::vector in_tensor_strides(in.GetStrides().begin(), in.GetStrides().end()); + + // add device softmax instances + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using DeviceOp = tensor_operation::device:: + DeviceSoftmax; + + // get device op instances + const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + std::cout << "found " << instances.size() << " instances" << std::endl; + + if(instances.size() <= 0) + { + throw std::runtime_error("wrong! no device normalization instance found"); + } + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + std::vector instance_pass; + + for(auto& inst_ptr : instances) + { + // Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3 + // problem to rank 4 kernel) other than invoking IsSupportedArgument()? + if(!(inst_ptr->GetNumReduceDim() == static_cast(reduce_dims.size()))) + { + continue; + } + + auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths, + in_tensor_strides, + reduce_dims, + &alpha, + &beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + + if(!inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = [", in_length, ", ") + << "], " + << "scaler = [" << alpha << ", " << beta << "]"; + LogRange(std::cout << ", reduce dims = [", reduce_dims, ", ") << "]." << std::endl; + instance_pass.push_back(true); + continue; + } + + out_dev.ToDevice(prior_out.data()); + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + { + std::size_t num_bytes = + in.GetElementSize() * sizeof(InDataType) + + (beta == 0.0f ? 1 : 2) * out.GetElementSize() * sizeof(OutDataType); + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + } + + if(do_verification) + { + out_dev.FromDevice(out.data()); + bool pass = true; + if(std::is_same::value) + { + pass = pass && ck::utils::check_err( + out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1); + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_ref : ", out_ref.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + } + } + else + { + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_ref : ", out_ref.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + } + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "input lengths = [", in_length, ", ") + << "], " + << "scaler = [" << alpha << ", " << beta << "]." << std::endl; + } + instance_pass.push_back(pass); + } + } + if(time_kernel) + { + std::cout << "Best Perf for datatype = " << type_to_string() << "_" + << type_to_string() << ", "; + LogRange(std::cout << "length = ", in_tensor_lengths, ",") << ", "; + LogRange(std::cout << "stride = ", in_tensor_strides, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", "; + std::cout << "alpha = " << alpha << ", " + << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_instance_name << std::endl; + } + return std::all_of( + std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; }); +} + +} // namespace profiler +} // namespace ck diff --git a/3rdparty/composable_kernel/profiler/src/CMakeLists.txt b/3rdparty/composable_kernel/profiler/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bc87554cecec617659b9b05200a7e3dcdfbb3fd7 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/CMakeLists.txt @@ -0,0 +1,67 @@ +# ckProfiler +set(PROFILER_SOURCES + profiler.cpp + profile_gemm.cpp + profile_gemm_splitk.cpp + profile_gemm_bilinear.cpp + profile_gemm_bias_add_reduce.cpp + profile_gemm_add_add_fastgelu.cpp + profile_gemm_add_fastgelu.cpp + profile_gemm_fastgelu.cpp + profile_gemm_reduce.cpp + profile_batched_gemm.cpp + profile_batched_gemm_gemm.cpp + profile_batched_gemm_add_relu_gemm_add.cpp + profile_batched_gemm_reduce.cpp + profile_grouped_gemm.cpp + profile_conv_fwd.cpp + profile_conv_fwd_bias_relu.cpp + profile_conv_fwd_bias_relu_add.cpp + profile_conv_bwd_data.cpp + profile_grouped_conv_fwd.cpp + profile_grouped_conv_bwd_weight.cpp + profile_reduce.cpp + profile_groupnorm.cpp + profile_layernorm.cpp + profile_softmax.cpp + profile_batchnorm_fwd.cpp + profile_batchnorm_bwd.cpp +) + +set(PROFILER_EXECUTABLE ckProfiler) + +add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) +target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) + +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) + +rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/3rdparty/composable_kernel/profiler/src/profile_batched_gemm.cpp b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..907a373794f2d8a59a066b94a591c4890c5c4038 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm.cpp @@ -0,0 +1,203 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_batched_gemm_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "batched_gemm" +#define OP_DESC "Batched GEMM" + +int profile_batched_gemm(int argc, char* argv[]) +{ + if(argc != 18) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n"); + printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n"); + printf(" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n"); + printf(" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n"); + printf(" 3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchStrideA = std::stoi(argv[14]); + const int BatchStrideB = std::stoi(argv[15]); + const int BatchStrideC = std::stoi(argv[16]); + + const int BatchCount = std::stoi(argv[17]); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; + const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; + const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; + const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; + const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; + + const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; + const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB; + const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; + + bool pass = ck::profiler:: + profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + BatchStrideA_, + BatchStrideB_, + BatchStrideC_, + StrideA_, + StrideB_, + StrideC_, + BatchCount); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm); diff --git a/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f440a3094eb16b4eef1c5dee46cb08f6ab25d933 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp" +#include "profiler_operation_registry.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +#define OP_NAME "batched_gemm_add_relu_gemm_add" +#define OP_DESC "Batched GEMM+Add+Relu+GEMM+Add" + +int profile_batched_gemm_add_relu_gemm_add(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_NK_MN_NO_MO_MO, // 0 + MK_NK_MN_ON_MO_MO, // 1 + }; + + enum struct GemmDataType + { + F32_F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16_F16, // 1 + }; + + GemmDataType data_type = GemmDataType::F16_F16_F16_F16_F16_F16; + GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_MN_NO_MO_MO; + bool do_verification = true; + int init_method = 1; + bool do_log = 0; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 128; + ck::index_t BatchCount = 4; + ck::index_t StrideA0 = -1; + ck::index_t StrideB0 = -1; + ck::index_t StrideD0 = -1; + ck::index_t StrideB1 = -1; + ck::index_t StrideD1 = -1; + ck::index_t StrideE1 = -1; + ck::index_t BatchStrideA0 = -1; + ck::index_t BatchStrideB0 = -1; + ck::index_t BatchStrideD0 = -1; + ck::index_t BatchStrideB1 = -1; + ck::index_t BatchStrideD1 = -1; + ck::index_t BatchStrideE1 = -1; + + if(argc == 8) + { + data_type = static_cast(std::stoi(argv[2])); + layout = static_cast(std::stoi(argv[3])); + do_verification = std::stoi(argv[4]); + init_method = std::stoi(argv[5]); + do_log = std::stoi(argv[6]); + time_kernel = std::stoi(argv[7]); + } + else if(argc == 13) + { + data_type = static_cast(std::stoi(argv[2])); + layout = static_cast(std::stoi(argv[3])); + do_verification = std::stoi(argv[4]); + init_method = std::stoi(argv[5]); + do_log = std::stoi(argv[6]); + time_kernel = std::stoi(argv[7]); + + M = std::stoi(argv[8]); + N = std::stoi(argv[9]); + K = std::stoi(argv[10]); + O = std::stoi(argv[11]); + BatchCount = std::stoi(argv[12]); + } + else if(argc == 25) + { + data_type = static_cast(std::stoi(argv[2])); + layout = static_cast(std::stoi(argv[3])); + do_verification = std::stoi(argv[4]); + init_method = std::stoi(argv[5]); + do_log = std::stoi(argv[6]); + time_kernel = std::stoi(argv[7]); + + M = std::stoi(argv[8]); + N = std::stoi(argv[9]); + K = std::stoi(argv[10]); + O = std::stoi(argv[11]); + BatchCount = std::stoi(argv[12]); + + StrideA0 = std::stoi(argv[13]); + StrideB0 = std::stoi(argv[14]); + StrideD0 = std::stoi(argv[15]); + StrideB1 = std::stoi(argv[16]); + StrideD1 = std::stoi(argv[17]); + StrideE1 = std::stoi(argv[18]); + + BatchStrideA0 = std::stoi(argv[19]); + BatchStrideB0 = std::stoi(argv[20]); + BatchStrideD0 = std::stoi(argv[21]); + BatchStrideB1 = std::stoi(argv[22]); + BatchStrideD1 = std::stoi(argv[23]); + BatchStrideE1 = std::stoi(argv[24]); + } + else + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (1: fp16)\n"); + printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] " + "= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = " + "E1[m, o];)\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 12: M, N, K, O, Batch\n"); + printf("arg13 to 18: StrideA0, StrideB0, StrideD0, StrideB1, StrideD1, StrideE1\n"); + printf("arg19 to 24: BatchStrideA0, BatchStrideB0, BatchStrideD0, BatchStrideB1, " + "BatchStrideD1, BatchStrideE1 \n"); + exit(1); + } + + if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 && + layout == GemmMatrixLayout::MK_NK_MN_NO_MO_MO) + { + ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl, // D0sLayout, + Row, // B1Layout, + ck::Tuple, // D1sLayout, + Row, // E1Layout, + F16, // A0DataType, + F16, // B0DataType, + ck::Tuple, // D0DataType, + F16, // B1DataType, + ck::Tuple, // D1sDataType + F16> // E1DataType, + (do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + O, + BatchCount, + StrideA0, + StrideB0, + StrideD0, + StrideB1, + StrideD1, + StrideE1, + BatchStrideA0, + BatchStrideB0, + BatchStrideD0, + BatchStrideB1, + BatchStrideD1, + BatchStrideE1); + } + else if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 && + layout == GemmMatrixLayout::MK_NK_MN_ON_MO_MO) + { + ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl, // D0sLayout, + Col, // B1Layout, + ck::Tuple, // D1sLayout, + Row, // E1Layout, + F16, // A0DataType, + F16, // B0DataType, + ck::Tuple, // D0DataType, + F16, // B1DataType, + ck::Tuple, // D1sDataType + F16> // E1DataType, + (do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + O, + BatchCount, + StrideA0, + StrideB0, + StrideD0, + StrideB1, + StrideD1, + StrideE1, + BatchStrideA0, + BatchStrideB0, + BatchStrideD0, + BatchStrideB1, + BatchStrideD1, + BatchStrideE1); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm_add_relu_gemm_add); diff --git a/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_gemm.cpp b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6015c93be35268c1c87c129d0cc7d2c9bdf0584f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_gemm.cpp @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_batched_gemm_gemm_impl.hpp" +#include "profiler_operation_registry.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +#define OP_NAME "batched_gemm_gemm" +#define OP_DESC "Batched GEMM+GEMM" + +int profile_batched_gemm_gemm(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_NK_NO_MO, // 0 + MK_NK_ON_MO, // 0 + }; + + enum struct GemmDataType + { + F32_F32_F32_F32, // 0 + F16_F16_F16_F16, // 1 + }; + + GemmDataType data_type = GemmDataType::F16_F16_F16_F16; + GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_NO_MO; + bool do_verification = true; + int init_method = 1; + bool do_log = 0; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 128; + ck::index_t BatchCount = 4; + ck::index_t StrideA0 = -1; + ck::index_t StrideB0 = -1; + ck::index_t StrideB1 = -1; + ck::index_t StrideE1 = -1; + ck::index_t BatchStrideA0 = -1; + ck::index_t BatchStrideB0 = -1; + ck::index_t BatchStrideB1 = -1; + ck::index_t BatchStrideE1 = -1; + + if(argc == 8) + { + data_type = static_cast(std::stoi(argv[2])); + layout = static_cast(std::stoi(argv[3])); + do_verification = std::stoi(argv[4]); + init_method = std::stoi(argv[5]); + do_log = std::stoi(argv[6]); + time_kernel = std::stoi(argv[7]); + } + else if(argc == 13) + { + data_type = static_cast(std::stoi(argv[2])); + layout = static_cast(std::stoi(argv[3])); + do_verification = std::stoi(argv[4]); + init_method = std::stoi(argv[5]); + do_log = std::stoi(argv[6]); + time_kernel = std::stoi(argv[7]); + + M = std::stoi(argv[8]); + N = std::stoi(argv[9]); + K = std::stoi(argv[10]); + O = std::stoi(argv[11]); + BatchCount = std::stoi(argv[12]); + } + else if(argc == 21) + { + data_type = static_cast(std::stoi(argv[2])); + layout = static_cast(std::stoi(argv[3])); + do_verification = std::stoi(argv[4]); + init_method = std::stoi(argv[5]); + do_log = std::stoi(argv[6]); + time_kernel = std::stoi(argv[7]); + + M = std::stoi(argv[8]); + N = std::stoi(argv[9]); + K = std::stoi(argv[10]); + O = std::stoi(argv[11]); + BatchCount = std::stoi(argv[12]); + + StrideA0 = std::stoi(argv[13]); + StrideB0 = std::stoi(argv[14]); + StrideB1 = std::stoi(argv[15]); + StrideE1 = std::stoi(argv[16]); + + BatchStrideA0 = std::stoi(argv[17]); + BatchStrideB0 = std::stoi(argv[18]); + BatchStrideB1 = std::stoi(argv[19]); + BatchStrideE1 = std::stoi(argv[20]); + } + else + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (1: fp16)\n"); + printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] " + "= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = E1[m, " + "o];)\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 12: M, N, K, O, Batch\n"); + printf("arg13 to 16: StrideA0, StrideB0, StrideB1, StrideE1\n"); + printf("arg17 to 20: BatchStrideA0, BatchStrideB0, BatchStrideB1, BatchStrideE1 \n"); + exit(1); + } + + if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_NO_MO) + { + ck::profiler::profile_batched_gemm_gemm_impl // E1Layout, + (do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + O, + BatchCount, + StrideA0, + StrideB0, + StrideB1, + StrideE1, + BatchStrideA0, + BatchStrideB0, + BatchStrideB1, + BatchStrideE1); + } + else if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_ON_MO) + { + ck::profiler::profile_batched_gemm_gemm_impl // E1Layout, + (do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + O, + BatchCount, + StrideA0, + StrideB0, + StrideB1, + StrideE1, + BatchStrideA0, + BatchStrideB0, + BatchStrideB1, + BatchStrideE1); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm_gemm); diff --git a/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_reduce.cpp b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b1dfc01427a1cd34eefc89584bb6f3ce8bf206f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_batched_gemm_reduce.cpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_batched_gemm_reduce_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "batched_gemm_reduce" +#define OP_DESC "Batched GEMM+Reduce" + +int profile_batched_gemm_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F32_F32, // 1 + }; + + if(argc != 15) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchCount = std::stoi(argv[14]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm_reduce); diff --git a/3rdparty/composable_kernel/profiler/src/profile_batchnorm_bwd.cpp b/3rdparty/composable_kernel/profiler/src/profile_batchnorm_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44ce7350ff06eaad3eff7d93f60c022d1a4c2b60 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_batchnorm_bwd.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/library/utility/host_common_util.hpp" +#include "profiler/profile_batchnorm_backward_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +using namespace std; + +static const struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"reduceDims", required_argument, nullptr, 'R'}, + {"dumpout", required_argument, nullptr, 'o'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchnormBwdArgParser +{ + private: + int option_index = 0; + + public: + std::vector inLengths; + std::vector reduceDims; + + bool do_verification = false; + bool do_dumpout = false; + + bool haveSavedMeanInvVar; + + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + + BatchnormBwdArgParser() = default; + ~BatchnormBwdArgParser() = default; + + void show_usage(const char* cmd) + { + // clang-format off + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl; + std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2 -- 1/0 to indicate whether to use saved mean and invVariance" << std::endl; + std::cout << "Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl; + std::cout << "Arg4 -- time kernel (0=no, 1=yes)" << std::endl; + // clang-format on + }; + + int operator()(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + optind++; // to skip the module name + + while(1) + { + ch = getopt_long(argc, argv, "D:R:v:o:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case 'o': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_dumpout = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return -1; + }; + break; + + default: + show_usage(argv[0]); + std::cerr << "Invalid cmd-line options!" << std::endl; + return -1; + }; + }; + + if(optind + 4 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + haveSavedMeanInvVar = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind++])); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return -1; + + return 0; + }; +}; // end of class AppArgs + +static const double epsilon = std::numeric_limits::epsilon(); + +int profile_batchnorm_backward(int argc, char* argv[]) +{ + using ck::profiler::profile_batchnorm_backward_impl; + + BatchnormBwdArgParser arg_parser; + + if(arg_parser(argc, argv) != 0) + return -1; + + using F16 = ck::half_t; + using F32 = float; + using BF16 = ck::bhalf_t; + using F64 = double; + + if(arg_parser.data_type == 0) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + else if(arg_parser.data_type == 1) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + else if(arg_parser.data_type == 5) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + else if(arg_parser.data_type == 6) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("bnorm_bwd", "Batchnorm backward", profile_batchnorm_backward); diff --git a/3rdparty/composable_kernel/profiler/src/profile_batchnorm_fwd.cpp b/3rdparty/composable_kernel/profiler/src/profile_batchnorm_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..902a1fc423f98880b43e1551f3305854d1a0b9b8 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_batchnorm_fwd.cpp @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/library/utility/host_common_util.hpp" +#include "profiler/profile_batchnorm_forward_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +using namespace std; + +static const struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"reduceDims", required_argument, nullptr, 'R'}, + {"dumpout", required_argument, nullptr, 'o'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchnormFwdArgParser +{ + private: + int option_index = 0; + + public: + std::vector inLengths; + std::vector reduceDims; + + bool do_verification = false; + bool do_dumpout = false; + + bool updateMovingAverage; + bool saveMeanAndInvVariance; + + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + + BatchnormFwdArgParser() = default; + ~BatchnormFwdArgParser() = default; + + void show_usage(const char* cmd) + { + // clang-format off + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl; + std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance (0=no, 1=yes)" << std::endl; + std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance (0=no, 1=yes)" << std::endl; + std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl; + std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; + // clang-format on + }; + + int operator()(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + optind++; // to skip the module name + + while(1) + { + ch = getopt_long(argc, argv, "D:R:v:o:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case 'o': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_dumpout = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return -1; + }; + break; + + default: + show_usage(argv[0]); + std::cerr << "Invalid cmd-line options!" << std::endl; + return -1; + }; + }; + + if(optind + 5 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + updateMovingAverage = std::atoi(argv[optind++]); + saveMeanAndInvVariance = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind++])); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return -1; + + return 0; + }; +}; // end of class AppArgs + +static const double epsilon = std::numeric_limits::epsilon(); +static const double averageFactor = 0.1; + +int profile_batchnorm_forward(int argc, char* argv[]) +{ + using ck::profiler::profile_batchnorm_forward_impl; + + BatchnormFwdArgParser arg_parser; + + if(arg_parser(argc, argv) != 0) + return -1; + + using F16 = ck::half_t; + using F32 = float; + using BF16 = ck::bhalf_t; + using F64 = double; + + if(arg_parser.data_type == 0) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_forward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.updateMovingAverage, + arg_parser.saveMeanAndInvVariance, + epsilon, + averageFactor); + }; + } + else if(arg_parser.data_type == 1) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_forward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.updateMovingAverage, + arg_parser.saveMeanAndInvVariance, + epsilon, + averageFactor); + }; + } + else if(arg_parser.data_type == 5) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_forward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.updateMovingAverage, + arg_parser.saveMeanAndInvVariance, + epsilon, + averageFactor); + }; + } + else if(arg_parser.data_type == 6) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_forward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.updateMovingAverage, + arg_parser.saveMeanAndInvVariance, + epsilon, + averageFactor); + }; + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("bnorm_fwd", "Batchnorm forward", profile_batchnorm_forward); diff --git a/3rdparty/composable_kernel/profiler/src/profile_conv_bwd_data.cpp b/3rdparty/composable_kernel/profiler/src/profile_conv_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9241ead738ef8a8a67d8682d743f6c9bd0faf640 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_conv_bwd_data.cpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_conv_bwd_data_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + NCHW_KCYX_NKHW, // 0 + NHWC_KYXC_NHWK, // 1 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "conv_bwd_data" +#define OP_DESC "Convolution Backward Data" + +static void print_helper_msg() +{ + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8)\n" + << "arg3: tensor layout (0: Input[N, C, Hi, Wi], Weight[K, C, Y, X], Output[N, K, Ho, Wo]\n" + << " 1: Input[N, Hi, Wi, C], Weight[K, Y, X, C], Output[N, Ho, Wo, " + "K])\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +} // namespace + +int profile_conv_bwd_data(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + bool pass = ck::profiler::profile_conv_bwd_data_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 1 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I1, NWC{}, KXC{}, NWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I1, NWC{}, KXC{}, NWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_conv_bwd_data); diff --git a/3rdparty/composable_kernel/profiler/src/profile_conv_fwd.cpp b/3rdparty/composable_kernel/profiler/src/profile_conv_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b57ee7fd94261f4437d1b62f0b6450ea74522db9 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_conv_fwd.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_conv_fwd_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + NCHW_KCYX_NKHW, // 0 + NHWC_KYXC_NHWK, // 1 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "conv_fwd" +#define OP_DESC "Convolution Forward" + +static void print_helper_msg() +{ + std::cout + // clang-format-off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8)\n" + << "arg3: tensor layout (0: Input[N, C, Hi, Wi], Weight[K, C, Y, X], Output[N, K, Ho, Wo]\n" + << " 1: Input[N, Hi, Wi, C], Weight[K, Y, X, C], Output[N, Ho, Wo, " + "K])\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format-on +} + +} // namespace + +int profile_conv_fwd(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + bool pass = ck::profiler::profile_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 1 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I1, NWC{}, KXC{}, NWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I1, NWC{}, KXC{}, NWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_conv_fwd); diff --git a/3rdparty/composable_kernel/profiler/src/profile_conv_fwd_bias_relu.cpp b/3rdparty/composable_kernel/profiler/src/profile_conv_fwd_bias_relu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b44007cde4742efcec188f0b620e2cf83295f5b6 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_conv_fwd_bias_relu_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +#define OP_NAME "conv_fwd_bias_relu" +#define OP_DESC "Convolution Forward+Bias+ReLU" + +int profile_conv_fwd_bias_relu(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_conv_fwd_bias_relu); diff --git a/3rdparty/composable_kernel/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/3rdparty/composable_kernel/profiler/src/profile_conv_fwd_bias_relu_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..408dd02f78dd6241b5f0a67ae7aac4420c891943 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_conv_fwd_bias_relu_add_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +enum struct ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum struct ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum struct ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +#define OP_NAME "conv_fwd_bias_relu_add" +#define OP_DESC "Convolution Forward+Bias+ReLU+Add" + +int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto in_layout = static_cast(std::stoi(argv[3])); + const auto wei_layout = static_cast(std::stoi(argv[4])); + const auto out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_fwd_bias_relu_add_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + time_kernel, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_conv_fwd_bias_relu_add); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61bae6ae70ebcf965618a100c3f553de73182118 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "gemm" +#define OP_DESC "GEMM" + +static void print_helper_msg() +{ + std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" + << " 1: A[m, k] * B[n, k] = C[m, n];\n" + << " 2: A[k, m] * B[k, n] = C[m, n];\n" + << " 3: A[k, m] * B[n, k] = C[m, n])\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n" + << std::endl; +} + +int profile_gemm(int argc, char* argv[]) +{ + if(argc != 14) + { + print_helper_msg(); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + using INT32 = int32_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_layout, + auto b_layout, + auto c_layout, + auto a_type, + auto b_type, + auto acc_type, + auto c_type) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = + ck::profiler::profile_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm_add_add_fastgelu.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm_add_add_fastgelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3c0fb7b67daf131ab53194bbf28a12aeedf62aa --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm_add_add_fastgelu.cpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_add_add_fastgelu_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_add_add_fastgelu" +#define OP_DESC "GEMM+Add+Add+FastGeLU" + +int profile_gemm_add_add_fastgelu(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN_MN, // 0 + MK_NK_MN_MN_MN, // 1 + KM_KN_MN_MN_MN, // 2 + KM_NK_MN_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16, // 1 + BF16_BF16_BF16_BF16_BF16, // 2 + INT8_INT8_INT8_INT8_INT8, // 3 + }; + + if(argc != 16) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n"); + printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n"); + printf(" 2: E[m, n] = FastGeLU(A[k, m] * B[k, n] + D0[m, n] + D1[m, n]);\n"); + printf(" 3: E[m, n] = FastGeLU(A[k, m] * B[n, k] + D0[m, n] + D1[m, n]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideD1 = std::stoi(argv[14]); + const int StrideE = std::stoi(argv[15]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto d1_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto d1_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using D1DataType = decltype(d1_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using D1Layout = decltype(d1_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideD1 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::MK_NK_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::KM_KN_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::KM_NK_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_add_fastgelu); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm_add_fastgelu.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm_add_fastgelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..380b25a614c58d42209135218cca331de3bebe33 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm_add_fastgelu.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_add_fastgelu_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_add_fastgelu" +#define OP_DESC "GEMM+Add+FastGeLU" + +int profile_gemm_add_fastgelu(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN, // 0 + MK_NK_MN_MN, // 1 + KM_KN_MN_MN, // 2 + KM_NK_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F32_F32_F32_F32, // 0 + F16_F16_F16_F16, // 1 + BF16_BF16_BF16_BF16, // 2 + INT8_INT8_INT8_INT8, // 3 + }; + + if(argc != 15) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n]);\n"); + printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n]);\n"); + printf(" 2: E[m, n] = FastGeLU(A[k, m] * B[k, n] + D0[m, n]);\n"); + printf(" 3: E[m, n] = FastGeLU(A[k, m] * B[n, k] + D0[m, n]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideE = std::stoi(argv[14]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_add_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_NK_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_KN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_NK_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_fastgelu); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm_bias_add_reduce.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm_bias_add_reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6d86db08223a771ddc4373c014a6a2a42f92285b --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm_bias_add_reduce.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_bias_add_reduce_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_bias_add_reduce" +#define OP_DESC "GEMM+Bias+Add+Reduce" + +int profile_gemm_bias_add_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int StrideC1 = std::stoi(argv[14]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_bias_add_reduce); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm_bilinear.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm_bilinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3480014ba6e6a2088c1ee1a221eb99eee6191241 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm_bilinear.cpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_bilinear_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_bilinear" +#define OP_DESC "GEMM+Bilinear" + +int profile_gemm_bilinear(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN, // 0 + MK_NK_MN_MN, // 1 + KM_KN_MN_MN, // 2 + KM_NK_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F32_F32_F32_F32, // 0 + F16_F16_F16_F16, // 1 + BF16_BF16_BF16_BF16, // 2 + INT8_INT8_INT8_INT8, // 3 + }; + + if(argc != 17) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: E[m, n] = alpha * A[m, k] * B[k, n] + beta * D[m, n];\n"); + printf(" 1: E[m, n] = alpha * A[m, k] * B[n, k] + beta * D[m, n];\n"); + printf(" 2: E[m, n] = alpha * A[k, m] * B[k, n] + beta * D[m, n];\n"); + printf(" 3: E[m, n] = alpha * A[k, m] * B[n, k] + beta * D[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD, StrideE\n"); + printf("arg15 to 16: alhpa, beta\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD = std::stoi(argv[13]); + const int StrideE = std::stoi(argv[14]); + + const float alpha = std::stof(argv[15]); + const float beta = std::stof(argv[16]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using DDataType = decltype(d_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using DLayout = decltype(d_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_bilinear_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD < 0) ? DefaultStrideD : StrideD, + (StrideE < 0) ? DefaultStrideE : StrideE, + alpha, + beta); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::MK_NK_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_KN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16 && layout == MatrixLayout::KM_NK_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_bilinear); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm_fastgelu.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm_fastgelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2a137224cb096f4bf2968f06be13d71d44614794 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm_fastgelu.cpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_fastgelu_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_fastgelu" +#define OP_DESC "GEMM+FastGeLU" + +int profile_gemm_fastgelu(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct MatrixDataType + { + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + }; + + if(argc != 14) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n]);\n"); + printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k]);\n"); + printf(" 2: E[m, n] = FastGeLU(A[k, m] * B[k, n]);\n"); + printf(" 3: E[m, n] = FastGeLU(A[k, m] * B[n, k]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideE = std::stoi(argv[13]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto e_type, + auto a_layout, + auto b_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_F16_F16 && layout == MatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16 && layout == MatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16 && layout == MatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16 && layout == MatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_fastgelu); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm_reduce.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm_reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..395bf0627e43617d04d30a903a1f3cbdecf4a557 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm_reduce.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_reduce_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_reduce" +#define OP_DESC "GEMM+Reduce" + +int profile_gemm_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_reduce); diff --git a/3rdparty/composable_kernel/profiler/src/profile_gemm_splitk.cpp b/3rdparty/composable_kernel/profiler/src/profile_gemm_splitk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f636ce718c669feb23a60fb1853a1219fea3da55 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_gemm_splitk.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_splitk_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "gemm_splitk" +#define OP_DESC "Split-K GEMM" + +int profile_gemm_splitk(int argc, char* argv[]) +{ + if(argc != 15) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int KBatch = std::stoi(argv[14]); + + using F32 = float; + using F16 = ck::half_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_splitk_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + KBatch); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_splitk); diff --git a/3rdparty/composable_kernel/profiler/src/profile_grouped_conv_bwd_weight.cpp b/3rdparty/composable_kernel/profiler/src/profile_grouped_conv_bwd_weight.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dfd8a099f5408f76fd51a9c69a988591edf5348f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + GNCHW_GKCYX_GNKHW, // 0 + GNHWC_GKYXC_GNHWK, // 1 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_F32_BF16, // 2 +}; + +#define OP_NAME "grouped_conv_bwd_weight" +#define OP_DESC "Grouped Convolution Backward Weight" + +static void print_helper_msg() +{ + std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight fp32, Output bf16)\n" + << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " + "N, K, Ho, Wo]\n" + << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " + "N, Ho, Wo, K]\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n" + << std::endl; +} + +} // namespace + +int profile_grouped_conv_bwd_weight(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + + ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]); + split_k = std::max(1, split_k); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + + using GNWC = ck::tensor_layout::convolution::GNWC; + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + + using GKXC = ck::tensor_layout::convolution::GKXC; + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + + using GNWK = ck::tensor_layout::convolution::GNWK; + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl( + do_verification, init_method, do_log, time_kernel, params, split_k); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_conv_bwd_weight); diff --git a/3rdparty/composable_kernel/profiler/src/profile_grouped_conv_fwd.cpp b/3rdparty/composable_kernel/profiler/src/profile_grouped_conv_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ff3c15af05a5f67988ef4842e00e2671289bb1b --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_grouped_conv_fwd.cpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "grouped_conv_fwd" +#define OP_DESC "Grouped Convolution Forward" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +} // namespace + +int profile_grouped_conv_fwd(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + + // + using GNWC = ck::tensor_layout::convolution::GNWC; + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + + using GKXC = ck::tensor_layout::convolution::GKXC; + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + + using GNWK = ck::tensor_layout::convolution::GNWK; + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + + // + using NWGC = ck::tensor_layout::convolution::NWGC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + + using NWGK = ck::tensor_layout::convolution::NWGK; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + // GNHWC_GKYXC_GNHWK + if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}); + } + } + // NHWGC_GKYXC_NHWGK + else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_conv_fwd); diff --git a/3rdparty/composable_kernel/profiler/src/profile_grouped_gemm.cpp b/3rdparty/composable_kernel/profiler/src/profile_grouped_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65e24bd9cc1fba07ab17010c1a95fdd02d991dfe --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_grouped_gemm.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_gemm_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +#define OP_NAME "grouped_gemm" +#define OP_DESC "Grouped GEMM" + +int profile_grouped_gemm(int argc, char* argv[]) +{ + if(!(argc == 14)) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + const auto StrideAs = argToIntArray(argv[11]); + const auto StrideBs = argToIntArray(argv[12]); + const auto StrideCs = argToIntArray(argv[13]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm); diff --git a/3rdparty/composable_kernel/profiler/src/profile_groupnorm.cpp b/3rdparty/composable_kernel/profiler/src/profile_groupnorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2741f52717a5b7e370be2ddf3420273fc78df0e5 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_groupnorm.cpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_groupnorm_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct GroupnormArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +#define OP_NAME "groupnorm" +#define OP_DESC "Group Normalization" + +void print_help_groupnorm() +{ + std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: fp16; 1: fp32)\n" + << "arg3: verification (0: no; 1: yes)\n" + << "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg5: print tensor value (0: no; 1: yes)\n" + << "arg6: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1 16 16 32 40) \n" + << std::endl; +} + +int profile_groupnorm(int argc, char* argv[]) +{ + ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; + bool do_verification = false; + int init_method = 0; + bool do_log = 0; + bool time_kernel = 1; + std::vector length = {64, 16, 16, 32, 40}; + + if(argc != 1 && argc != 13) + { + print_help_groupnorm(); + return 0; + } + + if(argc == 13) + { + data_type = static_cast(std::stoi(argv[2])); + do_verification = std::stoi(argv[3]); + init_method = std::stoi(argv[4]); + do_log = std::stoi(argv[5]); + time_kernel = std::stoi(argv[6]); + + // parse the long options + GroupnormArgParser arg_parser; + arg_parser(argc, argv); + length = arg_parser.long_opts["length"]; + } + + using F16 = ck::half_t; + using F32 = float; + + if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_groupnorm_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_groupnorm_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_groupnorm); diff --git a/3rdparty/composable_kernel/profiler/src/profile_layernorm.cpp b/3rdparty/composable_kernel/profiler/src/profile_layernorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e93fc2dbd2bf004389eed968e0518a2cdb1ea4f6 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_layernorm.cpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_layernorm_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct LayernormArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_layernorm() +{ + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1024 1024) \n" + << std::endl; +} + +int profile_layernorm(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help_layernorm(); + return 0; + } + + LayernormArgParser arg_parser; + + // short unnamed options + const ck::DataTypeEnum data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + + using F16 = ck::half_t; + using F32 = float; + constexpr int rank = 2; + + if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_layernorm_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_layernorm_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("layernorm", "Layer Normalization", profile_layernorm); diff --git a/3rdparty/composable_kernel/profiler/src/profile_reduce.cpp b/3rdparty/composable_kernel/profiler/src/profile_reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6925371858ee913dcb5d2b529c3e350a6eec5f4f --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_reduce.cpp @@ -0,0 +1,434 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/reduction_enums.hpp" + +#include "ck/library/utility/host_common_util.hpp" + +#include "profiler/profile_reduce_impl.hpp" +#include "profiler/data_type_enum.hpp" +#include "profiler_operation_registry.hpp" + +using namespace std; + +using ck::ReduceTensorOp; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"reduceDims", required_argument, nullptr, 'R'}, + {"reduceOp", required_argument, nullptr, 'O'}, + {"compType", required_argument, nullptr, 'C'}, + {"outType", required_argument, nullptr, 'W'}, + {"nanOpt", required_argument, nullptr, 'N'}, + {"indicesOpt", required_argument, nullptr, 'I'}, + {"scales", required_argument, nullptr, 'S'}, + {"half", no_argument, nullptr, '?'}, + {"double", no_argument, nullptr, '?'}, + {"int8", no_argument, nullptr, '?'}, + {"bf16", no_argument, nullptr, '?'}, + {"dumpout", required_argument, nullptr, 'o'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +static void check_reduce_dims(const int rank, const std::vector& reduceDims) +{ + for(auto dim : reduceDims) + { + if(dim < 0 || dim >= rank) + throw std::runtime_error("Invalid dimension index specified for Reducing"); + }; + + unsigned int flag = 0; + + for(auto dim : reduceDims) + { + if(flag & (0x1 << dim)) + throw std::runtime_error("All toReduce dimensions should be different!"); + flag = flag | (0x1 << dim); + }; +}; + +class ReduceProfilerArgs +{ + private: + int option_index = 0; + + public: + bool use_half = false; + bool use_double = false; + bool use_int8 = false; + bool use_bf16 = false; + + std::vector inLengths; + std::vector outLengths; + std::vector reduceDims; + + std::vector scales; + + ReduceTensorOp reduceOp = ReduceTensorOp::ADD; + ck::DataTypeEnum compTypeId = ck::DataTypeEnum::Float; + ck::DataTypeEnum outTypeId = ck::DataTypeEnum::Float; + + bool compType_assigned = false; + bool outType_assigned = false; + + int nanOpt = 0; + int indicesOpt = 0; + bool do_verification = false; + bool do_dumpout = false; + + int init_method; + bool time_kernel; + + ReduceProfilerArgs() = default; + ~ReduceProfilerArgs() = default; + + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions" + << std::endl; + std::cout << "--reduceOp or -O, enum value indicating the reduction operations" + << std::endl; + std::cout << "--compType or -C, enum value indicating the type of accumulated values used " + "during the reduction" + << std::endl; + std::cout << "--outType or -W, optional enum value indicating the type of the reduced " + "output, which could be float when the input data is half" + << std::endl; + std::cout + << "--nanOpt or -N, 1/0 value indicates the selection to use or not use Nan-Propagation" + << std::endl; + std::cout << "--indicesOpt or -I, 1/0 value indicates the selection to use or not use " + "index in reduction" + << std::endl; + std::cout << "--scales or -S, comma separated two float values for alpha and beta" + << std::endl; + std::cout << "--half, use fp16 for the input and output tensor data types" << std::endl; + std::cout << "--double, use fp64 for the input and output tensor data types" << std::endl; + std::cout << "--int8, use int8 for the input and output tensor data types" << std::endl; + std::cout << "--bf16, use bfloat16 for the input and output tensor data types" << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "--dumpout or -o, 1/0 to indicate where to save the reduction result to files " + "for further analysis" + << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + optind++; // to skip the "reduce" module name + + while(1) + { + ch = getopt_long(argc, argv, "D:R:O:C:W:N:I:S:v:o:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'O': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceOp = static_cast(std::atoi(optarg)); + break; + case 'C': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + compTypeId = static_cast(std::atoi(optarg)); + compType_assigned = true; + break; + case 'W': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + outTypeId = static_cast(std::atoi(optarg)); + outType_assigned = true; + break; + case 'N': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + nanOpt = std::atoi(optarg); + break; + case 'I': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + indicesOpt = std::atoi(optarg); + break; + case 'S': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + scales = getTypeValuesFromString(optarg); + + if(scales.size() != 2) + throw std::runtime_error("Invalid option format!"); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case 'o': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_dumpout = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "half") + use_half = true; + else if(std::string(long_options[option_index].name) == "double") + use_double = true; + else if(std::string(long_options[option_index].name) == "int8") + use_int8 = true; + else if(std::string(long_options[option_index].name) == "bf16") + use_bf16 = true; + else if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + + default: + show_usage(argv[0]); + std::cerr << "Invalid cmd-line options!" << std::endl; + return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + if(reduceOp == ReduceTensorOp::MIN || reduceOp == ReduceTensorOp::MAX || + reduceOp == ReduceTensorOp::AMAX) + { + // for indexable operations, no need to assign compType and outType, just let them be + // same as inType + compType_assigned = false; + outType_assigned = false; + }; + + return (0); + }; + +}; // end of class AppArgs + +int profile_reduce(int argc, char* argv[]) +{ + using ck::DataTypeEnum; + using ck::profiler::profile_reduce_impl; + + ReduceProfilerArgs args; + + if(args.processArgs(argc, argv) < 0) + return (-1); + + int rank = args.inLengths.size(); + + check_reduce_dims(rank, args.reduceDims); + + if(args.reduceOp == ReduceTensorOp::MUL || args.reduceOp == ReduceTensorOp::NORM1) + throw std::runtime_error("MUL and NORM1 are not supported by composable kernel!"); + + if(args.use_half) + { + if(!args.compType_assigned) + args.compTypeId = DataTypeEnum::Half; + + if(args.outType_assigned && + (args.outTypeId != DataTypeEnum::Half && args.outTypeId != DataTypeEnum::Float)) + args.outTypeId = DataTypeEnum::Float; + + if(!args.outType_assigned) + args.outTypeId = DataTypeEnum::Half; + + if(args.compTypeId == DataTypeEnum::Half) + { + profile_reduce_impl( + args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Float) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + } + else if(args.use_double) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.use_int8) + { + if(!args.compType_assigned) + args.compTypeId = DataTypeEnum::Int8; + + if(args.outType_assigned && + (args.outTypeId != DataTypeEnum::Int8 && args.outTypeId != DataTypeEnum::Int32)) + args.outTypeId = DataTypeEnum::Int32; + + if(!args.outType_assigned) + args.outTypeId = DataTypeEnum::Int8; + + if(args.compTypeId == DataTypeEnum::Int8) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Int32) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + } + else if(args.use_bf16) + { + if(args.outType_assigned && + (args.outTypeId != DataTypeEnum::BFloat16 && args.outTypeId != DataTypeEnum::Float)) + args.outTypeId = DataTypeEnum::Float; + + if(!args.outType_assigned) + args.outTypeId = DataTypeEnum::BFloat16; + + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + { + if(args.compTypeId == DataTypeEnum::Float) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else if(args.compTypeId == DataTypeEnum::Double) + { + profile_reduce_impl(args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); + } + else + throw std::runtime_error("Invalid compType assignment!"); + }; + + return (0); +}; + +REGISTER_PROFILER_OPERATION("reduce", "Reduce", profile_reduce); diff --git a/3rdparty/composable_kernel/profiler/src/profile_softmax.cpp b/3rdparty/composable_kernel/profiler/src/profile_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30f627dd29e57f3fda5bef57353f943fca621701 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profile_softmax.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/profile_softmax_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; +using ck::profiler::SoftmaxDataType; + +struct ArgParser +{ + std::unordered_map> long_opts = { + {"length", {}}, {"stride", {}}, {"reduce", {}}, {"alpha", {}}, {"beta", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help() +{ + std::cout << "arg1: tensor operation (softmax)\n" + << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n" + << "arg3: verification (0: no; 1: yes)\n" + << "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg5: print tensor value (0: no; 1: yes)\n" + << "arg6: time kernel (0=n0, 1=yes)\n" + << "--length: tensor extents (e.g, --length 8 4 256) \n" + << "--stride: tensor strides (e.g, --stride 1024 256 1)\n" + << "--reduce: to-reduce dimensions (e.g, --reduce 2)\n" + << "--alpha: alpha scaling value\n" + << "--beta: beta scaling value\n" + << std::endl; +} + +int profile_softmax(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help(); + return 0; + } + + ArgParser arg_parser; + + // short unnamed options + const SoftmaxDataType data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + const std::vector stride = arg_parser.long_opts["stride"]; + const std::vector reduce = arg_parser.long_opts["reduce"]; + const index_t alpha = + arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0]; + const index_t beta = arg_parser.long_opts["beta"].empty() ? 0 : arg_parser.long_opts["beta"][0]; + + // Rank 3 + if(length.size() == 3) + { + if(data_type == SoftmaxDataType::F16_F16) + { + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + float(alpha), + float(beta)); + } + else if(data_type == SoftmaxDataType::F32_F32) + { + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + float(alpha), + float(beta)); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + // Rank 4 + else if(length.size() == 4) + { + if(data_type == SoftmaxDataType::F16_F16) + { + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + float(alpha), + float(beta)); + } + else if(data_type == SoftmaxDataType::F32_F32) + { + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + float(alpha), + float(beta)); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +// hijack main() for quick debugging +// int main(int argc, char* argv[]) +// { +// profile_normalization(argc, argv); +// return 0; +// } + +REGISTER_PROFILER_OPERATION("softmax", "Softmax", profile_softmax); diff --git a/3rdparty/composable_kernel/profiler/src/profiler.cpp b/3rdparty/composable_kernel/profiler/src/profiler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..080117e390c4df1bc14195649382267bcccc18f7 --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profiler.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "profiler_operation_registry.hpp" + +static void print_helper_message() +{ + std::cout << "arg1: tensor operation " << ProfilerOperationRegistry::GetInstance() << std::endl; +} + +int main(int argc, char* argv[]) +{ + if(argc == 1) + { + print_helper_message(); + } + else if(const auto operation = ProfilerOperationRegistry::GetInstance().Get(argv[1]); + operation.has_value()) + { + return (*operation)(argc, argv); + } + else + { + std::cerr << "cannot find operation: " << argv[1] << std::endl; + return EXIT_FAILURE; + } +} diff --git a/3rdparty/composable_kernel/profiler/src/profiler_operation_registry.hpp b/3rdparty/composable_kernel/profiler/src/profiler_operation_registry.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91ff291233066527609133dc155e67388119385a --- /dev/null +++ b/3rdparty/composable_kernel/profiler/src/profiler_operation_registry.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +class ProfilerOperationRegistry final +{ + ProfilerOperationRegistry() = default; + ~ProfilerOperationRegistry() = default; + + public: + using Operation = std::function; + + private: + struct Entry final + { + explicit Entry(std::string_view description, Operation operation) noexcept + : description_(description), operation_(std::move(operation)) + { + } + + std::string_view description_; + Operation operation_; + }; + + std::map entries_; + + friend std::ostream& operator<<(std::ostream& stream, const ProfilerOperationRegistry& registry) + { + stream << "{\n"; + for(auto& [name, entry] : registry.entries_) + { + stream << "\t" << name << ": " << entry.description_ << "\n"; + } + stream << "}"; + + return stream; + } + + public: + static ProfilerOperationRegistry& GetInstance() + { + static ProfilerOperationRegistry registry; + return registry; + } + + std::optional Get(std::string_view name) const + { + const auto found = entries_.find(name); + if(found == end(entries_)) + { + return std::nullopt; + } + + return (found->second).operation_; + } + + bool Add(std::string_view name, std::string_view description, Operation operation) + { + return entries_ + .emplace(std::piecewise_construct, + std::forward_as_tuple(name), + std::forward_as_tuple(description, std::move(operation))) + .second; + } +}; + +#define PP_CONCAT(x, y) PP_CONCAT_IMPL(x, y) +#define PP_CONCAT_IMPL(x, y) x##y + +#define REGISTER_PROFILER_OPERATION(name, description, operation) \ + static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \ + ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation) diff --git a/3rdparty/composable_kernel/rbuild.ini b/3rdparty/composable_kernel/rbuild.ini new file mode 100644 index 0000000000000000000000000000000000000000..3649cedf0aead2b76b9052125b9f2225330ce92b --- /dev/null +++ b/3rdparty/composable_kernel/rbuild.ini @@ -0,0 +1,8 @@ +[develop] +cxx = ${rocm_path}/bin/hipcc +cc = ${rocm_path}/llvm/bin/clang +ignore = pcre +deps = + -f dev-requirements.txt +define = + BUILD_DEV=On diff --git a/3rdparty/composable_kernel/requirements.txt b/3rdparty/composable_kernel/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/3rdparty/composable_kernel/requirements.txt @@ -0,0 +1 @@ + diff --git a/3rdparty/composable_kernel/script/clang-format-overwrite.sh b/3rdparty/composable_kernel/script/clang-format-overwrite.sh new file mode 100644 index 0000000000000000000000000000000000000000..f9d11fcd8cb1144fc5e1091e848878e92878dbd2 --- /dev/null +++ b/3rdparty/composable_kernel/script/clang-format-overwrite.sh @@ -0,0 +1,2 @@ +#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' diff --git a/3rdparty/composable_kernel/script/cmake-ck-dev.sh b/3rdparty/composable_kernel/script/cmake-ck-dev.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e605ce8def00add5c2390441ab7a7071d39f073 --- /dev/null +++ b/3rdparty/composable_kernel/script/cmake-ck-dev.sh @@ -0,0 +1,19 @@ +#!/bin/bash +rm -f CMakeCache.txt +rm -f *.cmake +rm -rf CMakeFiles + +MY_PROJECT_SOURCE=$1 + +cmake \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_BUILD_TYPE=Release \ +-D BUILD_DEV=ON \ +-D GPU_TARGETS="gfx908;gfx90a" \ +-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +-D USE_BITINT_EXTENSION_INT4=OFF \ +${MY_PROJECT_SOURCE} + +#-D AMDGPU_TARGETS=gfx90a;gfx908 diff --git a/3rdparty/composable_kernel/script/cmake-ck-release.sh b/3rdparty/composable_kernel/script/cmake-ck-release.sh new file mode 100644 index 0000000000000000000000000000000000000000..268b1ebf9b7fa1743ef06b8cf3073dc44c0a7bc1 --- /dev/null +++ b/3rdparty/composable_kernel/script/cmake-ck-release.sh @@ -0,0 +1,19 @@ +#!/bin/bash +rm -f CMakeCache.txt +rm -f *.cmake +rm -rf CMakeFiles + +MY_PROJECT_SOURCE=$1 + +cmake \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_CXX_FLAGS="-O3" \ +-D CMAKE_BUILD_TYPE=Release \ +-D BUILD_DEV=OFF \ +-D GPU_TARGETS="gfx908;gfx90a" \ +-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +-D USE_BITINT_EXTENSION_INT4=OFF \ +${MY_PROJECT_SOURCE} + +#-D AMDGPU_TARGETS=gfx90a;gfx908 diff --git a/3rdparty/composable_kernel/script/count_vgpr.sh b/3rdparty/composable_kernel/script/count_vgpr.sh new file mode 100644 index 0000000000000000000000000000000000000000..07debc53a8cf23433ab3fefa2280743171ceffe9 --- /dev/null +++ b/3rdparty/composable_kernel/script/count_vgpr.sh @@ -0,0 +1,20 @@ +#!/bin/bash +FILE=$1 + +for num in {0..255} +do + base_pattern="(\[?${num}\b|\[\d*:${num}\])" + spattern="s${base_pattern}" + vpattern="v${base_pattern}" + apattern="a${base_pattern}" + scount=$(grep -P $spattern $FILE | wc -l) + vcount=$(grep -P $vpattern $FILE | wc -l) + acount=$(grep -P $apattern $FILE | wc -l) + bash -c "echo -n v${num} $vcount && \ + echo -n , s${num} $scount && \ + echo -n , a${num} $acount" + if [[ $scount -ne 0 || $vcount -ne 0 || $acount -ne 0 ]]; then + echo -n " *" + fi + echo "" +done diff --git a/3rdparty/composable_kernel/script/hipclang_opt.sh b/3rdparty/composable_kernel/script/hipclang_opt.sh new file mode 100644 index 0000000000000000000000000000000000000000..c51bd51d9732ecc618e4ef4d59832be67fb7486a --- /dev/null +++ b/3rdparty/composable_kernel/script/hipclang_opt.sh @@ -0,0 +1,25 @@ +rm *.ll *.s + +BC_FILE=$1 + +/opt/rocm/llvm/bin/llvm-dis $BC_FILE -o original.ll +/opt/rocm/llvm/bin/opt -S -inline -inline-threshold=104857 original.ll > inline.ll +/opt/rocm/llvm/bin/opt -S -sroa inline.ll > sroa.ll +/opt/rocm/llvm/bin/opt -S -O3 sroa.ll > o3.ll + +/opt/rocm/llvm/bin/llc -mcpu=gfx906 original.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 inline.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 sroa.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 o3.ll + +#/opt/rocm/llvm/bin/opt -S -O3 -sroa inline.ll > o3.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3.ll > o3_2.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3_2.ll > o3_3.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3_3.ll > o3_4.ll + +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 opt.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 inline.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_2.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_3.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_4.ll diff --git a/3rdparty/composable_kernel/script/parse_perf_data.py b/3rdparty/composable_kernel/script/parse_perf_data.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb13e6243da23c5292096da5f3e389c8bf1a3cb --- /dev/null +++ b/3rdparty/composable_kernel/script/parse_perf_data.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +import os, io, argparse, datetime, re +import numpy as np +import sqlalchemy +from sqlalchemy.types import NVARCHAR, Float, Integer +import pymysql +import pandas as pd +from sshtunnel import SSHTunnelForwarder + +def print_to_string(*args, **kwargs): + output = io.StringIO() + print(*args, file=output, **kwargs) + contents = output.getvalue() + output.close() + return contents + +def parse_args(): + parser = argparse.ArgumentParser(description='Parse results from tf benchmark runs') + parser.add_argument('filename', type=str, help='Log file to prase or directory containing log files') + args = parser.parse_args() + files = [] + if os.path.isdir(args.filename): + all_files = os.listdir(args.filename) + for name in all_files: + if not 'log' in name: + continue + files.append(os.path.join(args.filename, name)) + else: + files = [args.filename] + args.files = files + return args + +def main(): + args = parse_args() + tests = [] + kernels=[] + tflops=[] + dtype=[] + alayout=[] + blayout=[] + M=[] + N=[] + K=[] + StrideA=[] + StrideB=[] + StrideC=[] + #parse results, get the Tflops value for "Best Perf" kernels + + glue="" + for filename in args.files: + for line in open(filename): + if 'Branch name' in line: + lst=line.split() + branch_name=lst[2] + if 'On branch' in line: + lst=line.split() + branch_name=lst[2] + if 'Node name' in line: + lst=line.split() + node_id=lst[2] + if 'GPU_arch' in line: + lst=line.split() + gpu_arch=lst[2] + if 'HIP version' in line: + lst=line.split() + hip_vers=lst[2] + if 'Compute Unit' in line: + lst=line.split() + compute_units=lst[2] + if 'InstalledDir' in line: + lst=line.split() + rocm_vers=lst[1][lst[1].find('/opt/rocm-')+len('/opt/rocm-'):lst[1].rfind('/llvm/bin')] + print("Branch name:",branch_name) + print("Node name:",node_id) + print("GPU_arch:",gpu_arch) + print("Compute units:",compute_units) + print("ROCM_version:",rocm_vers) + print("HIP_version:",hip_vers) + + + #parse gemm performance tests: + if 'gemm' in filename: + for filename in args.files: + for line in open(filename): + if 'Best Perf' in line: + lst=line.split() + if len(lst)>=37: #the line is complete + tests.append(glue.join(lst[5:30])) + kernels.append(glue.join(lst[37:])) + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + elif len(lst)<37 and len(lst)>=33: #the tflops are available + tests.append(glue.join(lst[5:30])) + kernels.append("N/A") + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + print("warning: incomplete line:",lst) + elif len(lst)<33: #even the tflops are not available + print("Error in ckProfiler output!") + print("warning: incomplete line=",lst) + #sort results + #sorted_tests = sorted(tests) + #print("sorted tests:",sorted_tests) + sorted_tflops = [x for _,x in sorted(zip(tests,tflops))] + #sorted_kernels = [x for _,x in sorted(zip(tests,kernels))] + test_list=list(range(1,len(tests)+1)) + + #parse resnet50 performance tests: + if 'resnet50' in filename: + for filename in args.files: + for line in open(filename): + if 'Best Perf' in line: + lst=line.split() + tflops.append(lst[4]) + + print("Number of tests:",len(tflops)) + sql_hostname = '127.0.0.1' + sql_username = os.environ["dbuser"] + sql_password = os.environ["dbpassword"] + sql_main_database = 'miopen_perf' + sql_port = 3306 + ssh_host = os.environ["dbsship"] + ssh_user = os.environ["dbsshuser"] + ssh_port = int(os.environ["dbsshport"]) + ssh_pass = os.environ["dbsshpassword"] + + with SSHTunnelForwarder( + (ssh_host, ssh_port), + ssh_username=ssh_user, + ssh_password=ssh_pass, + remote_bind_address=(sql_hostname, sql_port)) as tunnel: + + sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'. + format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database)) + conn = sqlEngine.connect() + + #save gemm performance tests: + if 'gemm' in filename: + + #write the ck_gemm_test_params table + #only needed once the test set changes + ''' + sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))] + sorted_alayout = [x for _,x in sorted(zip(tests,alayout))] + sorted_blayout = [x for _,x in sorted(zip(tests,blayout))] + sorted_M = [x for _,x in sorted(zip(tests,M))] + sorted_N = [x for _,x in sorted(zip(tests,N))] + sorted_K = [x for _,x in sorted(zip(tests,K))] + sorted_StrideA = [x for _,x in sorted(zip(tests,StrideA))] + sorted_StrideB = [x for _,x in sorted(zip(tests,StrideB))] + sorted_StrideC = [x for _,x in sorted(zip(tests,StrideC))] + ck_gemm_params=[test_list,sorted_dtypes,sorted_alayout,sorted_blayout, + sorted_M,sorted_N,sorted_K,sorted_StrideA,sorted_StrideB, + sorted_StrideC] + df=pd.DataFrame(np.transpose(ck_gemm_params),columns=['Test_number','Data_type', + 'Alayout','BLayout','M','N','K', 'StrideA','StrideB','StrideC']) + print(df) + + dtypes = { + 'Test_number': Integer(), + 'Data_type': NVARCHAR(length=5), + 'Alayout': NVARCHAR(length=12), + 'Blayout': NVARCHAR(length=12), + 'M': Integer(), + 'N': Integer(), + 'K': Integer(), + 'StrideA': Integer(), + 'StrideB': Integer(), + 'StrideC': Integer() + } + df.to_sql("ck_gemm_test_params",conn,if_exists='replace',index=False, dtype=dtypes) + ''' + + #read baseline results for the latest develop branch + query = '''SELECT * from ck_gemm_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_gemm_tflops where Branch_ID='develop' );''' + tflops_base = pd.read_sql_query(query, conn) + + #write new results to the db + testlist=[] + for i in range(1,len(tests)+1): + testlist.append("Test%i"%i) + ck_gemm_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] + flops=pd.DataFrame(data=[ck_gemm_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) + df_add=pd.DataFrame(data=[sorted_tflops],columns=testlist) + flops=pd.concat([flops,df_add],axis=1) + print("new tflops for gemm tests:",flops) + flops.to_sql("ck_gemm_tflops",conn,if_exists='append',index=False) + + #save resnet50 performance tests: + if 'resnet50' in filename: + #read baseline results for the latest develop branch + query = '''SELECT * from ck_resnet50_N256_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N256_tflops where Branch_ID='develop' );''' + tflops_base_N256 = pd.read_sql_query(query, conn) + query = '''SELECT * from ck_resnet50_N4_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N4_tflops where Branch_ID='develop' );''' + tflops_base_N4 = pd.read_sql_query(query, conn) + + #write new results to the db + testlist=[] + for i in range(1,50): + testlist.append("Layer%i"%i) + ck_resnet_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] + flops0=pd.DataFrame(data=[ck_resnet_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) + df_add=pd.DataFrame(data=[tflops[0:49]],columns=testlist) + flops=pd.concat([flops0,df_add],axis=1) + print("new tflops for N=256 resnet50 test:",flops) + flops.to_sql("ck_resnet50_N256_tflops",conn,if_exists='append',index=False) + df_add=pd.DataFrame(data=[tflops[49:98]],columns=testlist) + flops=pd.concat([flops0,df_add],axis=1) + print("new tflops for N=4 resnet50 test:",flops) + flops.to_sql("ck_resnet50_N4_tflops",conn,if_exists='append',index=False) + + conn.close() + + #compare the results to the baseline if baseline exists + regression=0 + if 'gemm' in filename: + if not tflops_base.empty: + base=tflops_base[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(sorted_tflops[i]): + print("test # ",i,"shows regression by {:.3f}%".format( + (float(sorted_tflops[i])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(sorted_tflops[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline") + if 'resnet50' in filename: + if not tflops_base_N256.empty: + base=tflops_base_N256[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(tflops[i]): + print("layer # ",i,"shows regression by {:.3f}%".format( + (float(tflops[i])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(tflops[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline for N=256") + if not tflops_base_N4.empty: + base=tflops_base_N4[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(tflops[i+49]): + print("layer # ",i,"shows regression by {:.3f}%".format( + (float(tflops[i+49])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(tflops[i+49])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline for N=4") + + #return 0 if performance criteria met, otherwise return 1 + return regression + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/3rdparty/composable_kernel/script/process_perf_data.py b/3rdparty/composable_kernel/script/process_perf_data.py new file mode 100644 index 0000000000000000000000000000000000000000..638e4ef5644c7e814739893cfedabc6afeed28bb --- /dev/null +++ b/3rdparty/composable_kernel/script/process_perf_data.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +import os, io, argparse, datetime +#import numpy as np +import sqlalchemy +from sqlalchemy.types import NVARCHAR, Float, Integer +import pymysql +import pandas as pd +from sshtunnel import SSHTunnelForwarder + +def print_to_string(*args, **kwargs): + output = io.StringIO() + print(*args, file=output, **kwargs) + contents = output.getvalue() + output.close() + return contents + +def parse_args(): + parser = argparse.ArgumentParser(description='Parse results from tf benchmark runs') + parser.add_argument('filename', type=str, help='Log file to prase or directory containing log files') + args = parser.parse_args() + files = [] + if os.path.isdir(args.filename): + all_files = os.listdir(args.filename) + for name in all_files: + if not 'log' in name: + continue + files.append(os.path.join(args.filename, name)) + else: + files = [args.filename] + args.files = files + return args + +def get_log_params(logfile): + print("logfile=",logfile) + branch_name=' ' + node_id=' ' + gpu_arch=' ' + hip_vers=' ' + compute_units=0 + environment=' ' + rocm_vers=' ' + for line in open(logfile): + if 'Branch name' in line: + lst=line.split() + branch_name=lst[2] + if 'On branch' in line: + lst=line.split() + branch_name=lst[2] + if 'Node name' in line: + lst=line.split() + node_id=lst[2] + if 'GPU_arch' in line: + lst=line.split() + gpu_arch=lst[2] + if 'HIP version' in line: + lst=line.split() + hip_vers=lst[2] + if 'Compute Unit' in line: + lst=line.split() + compute_units=lst[2] + if 'Environment type' in line: + lst=line.split() + environment=lst[2] + if 'InstalledDir' in line: + lst=line.split() + rocm_vers=lst[1][lst[1].find('/opt/rocm-')+len('/opt/rocm-'):lst[1].rfind('/llvm/bin')] + return branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment + +def parse_logfile(logfile): + glue='' + res=[] + tests=[] + kernels=[] + tflops=[] + dtype=[] + alayout=[] + blayout=[] + M=[] + N=[] + K=[] + StrideA=[] + StrideB=[] + StrideC=[] + if 'perf_gemm.log' in logfile: + for line in open(logfile): + if 'Best Perf' in line: + lst=line.split() + if len(lst)>=37: #the line is complete + tests.append(glue.join(lst[5:30])) + kernels.append(glue.join(lst[37:])) + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + elif len(lst)<37 and len(lst)>=33: #the tflops are available + tests.append(glue.join(lst[5:30])) + kernels.append("N/A") + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + print("warning: incomplete line:",lst) + elif len(lst)<33: #even the tflops are not available + print("Error in ckProfiler output!") + print("warning: incomplete line=",lst) + #sort results + #sorted_tests = sorted(tests) + res = [x for _,x in sorted(zip(tests,tflops))] + #sorted_kernels = [x for _,x in sorted(zip(tests,kernels))] + test_list=list(range(1,len(tests)+1)) + #parse conv_fwd and conv_bwd performance tests: + elif 'conv_fwd' in logfile or 'conv_bwd_data' in logfile: + for line in open(logfile): + if 'tflops:' in line: + lst=line.split() + res.append(lst[1]) + #parse all other performance tests: + elif 'resnet50' in logfile or 'batched_gemm' in logfile or 'grouped_gemm' in logfile or 'gemm_bilinear' in logfile or 'reduction' in logfile: + for line in open(logfile): + if 'Best Perf' in line: + lst=line.split() + res.append(lst[4]) + elif 'onnx_gemm' in logfile or 'splitK_gemm' in logfile: + for line in open(logfile): + if 'Best Perf' in line: + lst=line.split() + res.append(lst[33]) + return res + + +def get_baseline(table, connection): + query = '''SELECT * from '''+table+''' WHERE Datetime = (SELECT MAX(Datetime) FROM '''+table+''' where Branch_ID='develop' );''' + return pd.read_sql_query(query, connection) + +def store_new_test_result(table_name, test_results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, connection): + params=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(environment),str(datetime.datetime.now())] + df=pd.DataFrame(data=[params],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Environment','Datetime']) + df_add=pd.DataFrame(data=[test_results],columns=testlist) + df=pd.concat([df,df_add],axis=1) + #print("new test results dataframe:",df) + df.to_sql(table_name,connection,if_exists='append',index=False) + return 0 + +def compare_test_to_baseline(baseline,test,testlist): + regression=0 + if not baseline.empty: + base=baseline[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(test[i]): + print("test # ",i,"shows regression by {:.3f}%".format( + (float(test[i])-base_list[i])/base_list[i]*100)) + regression=1 + if base_list[i]>0: ave_perf=ave_perf+float(test[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline") + return regression + +''' +def post_test_params(tlist,connection): + sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))] + sorted_alayout = [x for _,x in sorted(zip(tests,alayout))] + sorted_blayout = [x for _,x in sorted(zip(tests,blayout))] + sorted_M = [x for _,x in sorted(zip(tests,M))] + sorted_N = [x for _,x in sorted(zip(tests,N))] + sorted_K = [x for _,x in sorted(zip(tests,K))] + sorted_StrideA = [x for _,x in sorted(zip(tests,StrideA))] + sorted_StrideB = [x for _,x in sorted(zip(tests,StrideB))] + sorted_StrideC = [x for _,x in sorted(zip(tests,StrideC))] + ck_gemm_params=[tlist,sorted_dtypes,sorted_alayout,sorted_blayout, + sorted_M,sorted_N,sorted_K,sorted_StrideA,sorted_StrideB, + sorted_StrideC] + df=pd.DataFrame(np.transpose(ck_gemm_params),columns=['Test_number','Data_type', + 'Alayout','BLayout','M','N','K', 'StrideA','StrideB','StrideC']) + print(df) + + dtypes = { + 'Test_number': Integer(), + 'Data_type': NVARCHAR(length=5), + 'Alayout': NVARCHAR(length=12), + 'Blayout': NVARCHAR(length=12), + 'M': Integer(), + 'N': Integer(), + 'K': Integer(), + 'StrideA': Integer(), + 'StrideB': Integer(), + 'StrideC': Integer() + } + df.to_sql("ck_gemm_test_params",connection,if_exists='replace',index=False, dtype=dtypes) +''' + +def main(): + args = parse_args() + results=[] + tflops_base=[] + testlist=[] + #parse the test parameters from the logfile + for filename in args.files: + branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment = get_log_params(filename) + + print("Branch name:",branch_name) + print("Node name:",node_id) + print("GPU_arch:",gpu_arch) + print("Compute units:",compute_units) + print("ROCM_version:",rocm_vers) + print("HIP_version:",hip_vers) + print("Environment:",environment) + #parse results, get the Tflops value for "Best Perf" kernels + results=parse_logfile(filename) + + print("Number of tests:",len(results)) + sql_hostname = '127.0.0.1' + sql_username = os.environ["dbuser"] + sql_password = os.environ["dbpassword"] + sql_main_database = 'miopen_perf' + sql_port = 3306 + ssh_host = os.environ["dbsship"] + ssh_user = os.environ["dbsshuser"] + ssh_port = int(os.environ["dbsshport"]) + ssh_pass = os.environ["dbsshpassword"] + + with SSHTunnelForwarder( + (ssh_host, ssh_port), + ssh_username=ssh_user, + ssh_password=ssh_pass, + remote_bind_address=(sql_hostname, sql_port)) as tunnel: + + sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'. + format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database)) + conn = sqlEngine.connect() + + #save gemm performance tests: + if 'perf_gemm.log' in filename: + #write the ck_gemm_test_params table only needed once the test set changes + #post_test_params(test_list,conn) + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_gemm_tflops" + if 'batched_gemm' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_batched_gemm_tflops" + if 'grouped_gemm' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_grouped_gemm_tflops" + if 'conv_fwd' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_conv_fwd_tflops" + if 'conv_bwd_data' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_conv_bwd_data_tflops" + if 'gemm_bilinear' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_gemm_bilinear_tflops" + if 'reduction' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_reduction_GBps" + if 'resnet50_N4' in filename: + for i in range(1,50): + testlist.append("Layer%i"%i) + table_name="ck_resnet50_N4_tflops" + if 'resnet50_N256' in filename: + for i in range(1,50): + testlist.append("Layer%i"%i) + table_name="ck_resnet50_N256_tflops" + if 'onnx_gemm' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_onnx_gemm_tflops" + if 'splitK_gemm' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_splitK_gemm_tflops" + + tflops_base = get_baseline(table_name,conn) + store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn) + conn.close() + + #compare the results to the baseline if baseline exists + regression=0 + regression=compare_test_to_baseline(tflops_base,results,testlist) + return regression + +if __name__ == '__main__': + main() diff --git a/3rdparty/composable_kernel/script/process_perf_data.sh b/3rdparty/composable_kernel/script/process_perf_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..15fc5cb15f8b3f4ab5598afec87796fea5e7e80c --- /dev/null +++ b/3rdparty/composable_kernel/script/process_perf_data.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# +# in order to run this script you'd need the following python packages: + +#pip3 install --upgrade pip +#pip3 install sqlalchemy pymysql pandas sshtunnel + +# you would also need to set up some environment variables in order to +# post your new test results to the database and compare them to the baseline +# please contact Illia.Silin@amd.com for more details + +#process results +python3 process_perf_data.py perf_gemm.log +python3 process_perf_data.py perf_resnet50_N256.log +python3 process_perf_data.py perf_resnet50_N4.log diff --git a/3rdparty/composable_kernel/script/process_qa_data.sh b/3rdparty/composable_kernel/script/process_qa_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..abf1e6234e29ccf83ec9bb425a854b58ad5effdc --- /dev/null +++ b/3rdparty/composable_kernel/script/process_qa_data.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# +# in order to run this script you'd need the following python packages: + +#pip3 install --upgrade pip +#pip3 install sqlalchemy pymysql pandas sshtunnel + +# you would also need to set up some environment variables in order to +# post your new test results to the database and compare them to the baseline +# please contact Illia.Silin@amd.com for more details + +#process results +python3 process_perf_data.py perf_gemm.log +python3 process_perf_data.py perf_resnet50_N256.log +python3 process_perf_data.py perf_resnet50_N4.log +python3 process_perf_data.py perf_batched_gemm.log +python3 process_perf_data.py perf_grouped_gemm.log +python3 process_perf_data.py perf_conv_fwd.log +python3 process_perf_data.py perf_conv_bwd_data.log +python3 process_perf_data.py perf_gemm_bilinear.log +python3 process_perf_data.py perf_reduction.log +python3 process_perf_data.py perf_splitK_gemm.log +python3 process_perf_data.py perf_onnx_gemm.log diff --git a/3rdparty/composable_kernel/script/profile_batched_gemm.sh b/3rdparty/composable_kernel/script/profile_batched_gemm.sh new file mode 100644 index 0000000000000000000000000000000000000000..d19ddd0c6525dd0737187f904719ac4586e0722a --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_batched_gemm.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 + +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 + +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 -1 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 -1 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 -1 -1 -1 -1 4 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 -1 -1 -1 -1 2 + + ####### op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1024 1024 1024 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2048 2048 2048 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4096 4096 4096 -1 -1 -1 4 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8192 8192 8192 -1 -1 -1 2 + + ####### op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1056 1056 1056 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2080 2080 2080 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4128 4128 4128 -1 -1 -1 4 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8224 8224 8224 -1 -1 -1 2 + + ####### op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1088 1088 1088 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2112 2112 2112 -1 -1 -1 8 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4160 4160 4160 -1 -1 -1 4 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8256 8256 8256 -1 -1 -1 2 diff --git a/3rdparty/composable_kernel/script/profile_conv_bwd_data.sh b/3rdparty/composable_kernel/script/profile_conv_bwd_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1d2f450c96185b3f014efc388620b054677bc7c --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_conv_bwd_data.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" + +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 + + N=$8 + +# Resnet50 +######## op datatype layout verify init log time conv_dim G__ N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 diff --git a/3rdparty/composable_kernel/script/profile_conv_fwd.sh b/3rdparty/composable_kernel/script/profile_conv_fwd.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1d2f450c96185b3f014efc388620b054677bc7c --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_conv_fwd.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" + +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 + + N=$8 + +# Resnet50 +######## op datatype layout verify init log time conv_dim G__ N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2 1 $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 diff --git a/3rdparty/composable_kernel/script/profile_gemm.sh b/3rdparty/composable_kernel/script/profile_gemm.sh new file mode 100644 index 0000000000000000000000000000000000000000..b88159e74d79d5f3648ca3777fc36713dee77dcd --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_gemm.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 + + +# 120 CU +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 2048 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 1024 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 -1 + +# 104 CU +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 832 1024 1024 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 832 2048 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1664 1024 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1664 2048 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3328 4096 4096 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 6656 8192 8192 -1 -1 -1 + +# 110 CU +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1280 1408 1024 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1280 2816 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2560 1408 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2560 2816 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 5120 5632 4096 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7040 8192 8192 -1 -1 -1 + +# testing different strides +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1024 1024 1024 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2048 2048 2048 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4096 4096 4096 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8192 8192 8192 + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1056 1056 1056 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2080 2080 2080 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4128 4128 4128 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8224 8224 8224 + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1088 1088 1088 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2112 2112 2112 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4160 4160 4160 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8256 8256 8256 diff --git a/3rdparty/composable_kernel/script/profile_gemm_bilinear.sh b/3rdparty/composable_kernel/script/profile_gemm_bilinear.sh new file mode 100644 index 0000000000000000000000000000000000000000..e6edefae85bd082a716831b2c105b303a313e6e6 --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_gemm_bilinear.sh @@ -0,0 +1,41 @@ +#!/bin/bash +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 + +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 -1 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 -1 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 -1 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 -1 -1 1 1 + +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 0 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 0 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 0 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 0 -1 1 1 + +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1000 1000 1000 -1 -1 0 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2000 2000 2000 -1 -1 0 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4000 4000 4000 -1 -1 0 -1 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8000 8000 8000 -1 -1 0 -1 1 1 + +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1056 1056 1056 1056 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2080 2080 2080 2080 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4128 4128 4128 4128 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8224 8224 8224 8224 1 1 + +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1088 1088 1088 1088 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2112 2112 2112 2112 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4160 4160 4160 4160 1 1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8256 8256 8256 8256 1 1 \ No newline at end of file diff --git a/3rdparty/composable_kernel/script/profile_grouped_gemm.sh b/3rdparty/composable_kernel/script/profile_grouped_gemm.sh new file mode 100644 index 0000000000000000000000000000000000000000..8adb7c81ace83a1febc773d3b3424f09533847ce --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_grouped_gemm.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 + +######## op datatype layout verify init log time Ms______________ Ns______________ Ks_____________ StrideAs___________ StrideBs__________ StrideCs___________ + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 256,512,1024,768 128,256,384,1024 128,192,256,512 1024,1025,1044,1026 1024,1024,1024,1024 1025,1024,1028,1024 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 512,768,2048,128 128,256,384,1024 128,192,256,512 1024,1025,2053,1026 1024,1024,1024,1024 1025,1024,2054,1024 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 256,512,1024,768 512,256,768,1024 128,192,256,512 1024,1045,1034,1026 1024,1024,1024,1024 1025,1063,1028,1024 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 512,768,4096,768 128,768,512,2048 128,192,256,512 1024,1027,4096,2050 1024,1024,1024,2048 1025,1024,4099,2049 diff --git a/3rdparty/composable_kernel/script/profile_onnx_gemm.sh b/3rdparty/composable_kernel/script/profile_onnx_gemm.sh new file mode 100644 index 0000000000000000000000000000000000000000..c2721e7f59aea4150cbb5944e831326d73d7ed53 --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_onnx_gemm.sh @@ -0,0 +1,31 @@ +#!/bin/bash +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 +# GEMM kernel benchmarks used by ONNX +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 768 768 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 768 2304 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 768 3072 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 3072 768 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 1024 1024 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 1024 3072 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 1024 4096 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 384 4096 1024 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 768 768 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 768 2304 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 768 3072 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 3072 768 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 1024 1024 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 1024 3072 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 1024 4096 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24576 4096 1024 -1 -1 -1 + diff --git a/3rdparty/composable_kernel/script/profile_reduce_no_index.sh b/3rdparty/composable_kernel/script/profile_reduce_no_index.sh new file mode 100644 index 0000000000000000000000000000000000000000..66bfe1dcd34c6183b8ff64a9c5c39b7e53c60488 --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_reduce_no_index.sh @@ -0,0 +1,78 @@ +#!/bin/bash +DRIVER="../build/bin/ckProfiler" +VERIFY="-v $1" +INIT=$2 +NREPEAT=$3 +PRECISION=$4 +##PRECISION=--half +##PRECISION=--double +##PRECISION=--int8 +##PRECISION=--bf16 + +if [ -n $PRECISION ] && [ "$PRECISION" = "--half" -o "$PRECISION" = "--bf16" ]; then + ACCTYPE="-C 1" +elif [ -n $PRECISION ] && [ "$PRECISION" = "--int8" ]; then + ACCTYPE="-C 2" +fi + +#### 0 - ADD, 5 - AVG, 7 - NORM2 +Operations="0 5" + +#### 0 - ADD, 5 - AVG, for int8, no NORM2 supported +if [ -n $PRECISION ] && [ "$PRECISION" = "--int8" -o "$PRECISION" = "--half" ]; then + Operations=5 +fi + +## for generic validation +for op in $Operations; do + set -x + ####### datatype layout reduce dims op acctype verify init repeats + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,22960 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,22960 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 4,1469440 -R 0 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 4,1469440 -R 1 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + set +x +done + +#### 0 - ADD, 5 - AVG, 7 - NORM2 +Operations=5 + +## for performance evaluation (resnet50 NHWC => C) +for op in $Operations; do + set -x + ####### datatype layout reduce dims op acctype verify init repeats + $DRIVER reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op $ACCTYPE $VERIFY $INIT $NREPEAT + set +x +done + diff --git a/3rdparty/composable_kernel/script/profile_reduce_with_index.sh b/3rdparty/composable_kernel/script/profile_reduce_with_index.sh new file mode 100644 index 0000000000000000000000000000000000000000..43543f4430ba1d177d35453ff49d563d5e67754d --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_reduce_with_index.sh @@ -0,0 +1,70 @@ +#!/bin/bash +DRIVER="../build/bin/ckProfiler" +VERIFY="-v $1" +INIT=$2 +NREPEAT=$3 +PRECISION=$4 +##PRECISION=--half +##PRECISION=--double +##PRECISION=--int8 +##PRECISION=--bf16 + +#### 2 - MIN, 3 - MAX, 4 - AMAX +Operations="2 4" + +## for generic validation +for op in $Operations; do + for use_idx in 0 1; do + set -x + ####### datatype layout reduce dims op use index verify init repeats + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 1,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,2,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 64,4,280,82 -R 0,1,3 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,22960 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,22960 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 4,1469440 -R 0 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 4,1469440 -R 1 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + set +x + done +done + +Operations=2 + +## for performance evaluation (resnet50 NHWC => C) +for op in $Operations; do + for use_idx in 0 1; do + set -x + ####### datatype layout reduce dims op use index verify init repeats + $DRIVER reduce $PRECISION -D 256,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 256,230,230,3 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,14,14,1024 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,28,28,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,58,58,128 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,7,7,2048 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,14,14,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,30,30,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,56,56,256 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,16,16,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,28,28,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,7,7,512 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + $DRIVER reduce $PRECISION -D 128,56,56,64 -R 0,1,2 -O $op -I $use_idx $VERIFY $INIT $NREPEAT + set +x + done +done + diff --git a/3rdparty/composable_kernel/script/profile_resnet50.sh b/3rdparty/composable_kernel/script/profile_resnet50.sh new file mode 100644 index 0000000000000000000000000000000000000000..b55cb2cceff2fd0b21b0bac9a0ab0aca25737897 --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_resnet50.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" + +OP=$1 +DATATYPE=$2 +IN_LAYOUT=$3 +WEI_LAYOUT=$4 +OUT_LAYOUT=$5 +VERIFY=$6 +INIT=$7 +LOG=$8 +TIME=$9 + + N=${10} + +# Resnet50 +######## op____________________ datatype in_layout wei_layout out_layout verify init log time N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + $DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $TIME $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 diff --git a/3rdparty/composable_kernel/script/profile_splitK_gemm.sh b/3rdparty/composable_kernel/script/profile_splitK_gemm.sh new file mode 100644 index 0000000000000000000000000000000000000000..d62f0e4753dca2be3004433adaa15258340df232 --- /dev/null +++ b/3rdparty/composable_kernel/script/profile_splitK_gemm.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 +KBatch=$8 + + +# 120 CU +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_ + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 2048 2048 -1 -1 -1 $KBatch + +# 104 CU +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_ + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 832 1024 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 832 2048 2048 -1 -1 -1 $KBatch + +# 110 CU +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_ + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1280 1408 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1280 2816 2048 -1 -1 -1 $KBatch + +# testing different strides +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_ + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1024 1024 1024 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2048 2048 2048 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1056 1056 1056 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2080 2080 2080 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1088 1088 1088 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2112 2112 2112 $KBatch diff --git a/3rdparty/composable_kernel/script/run_full_performance_tests.sh b/3rdparty/composable_kernel/script/run_full_performance_tests.sh new file mode 100644 index 0000000000000000000000000000000000000000..eae334ae2dd57c26983fe3c070cb43c6acd34420 --- /dev/null +++ b/3rdparty/composable_kernel/script/run_full_performance_tests.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ +# you would also need to set up some environment variables in order to +# post your new test results to the database and compare them to the baseline +# please contact Illia.Silin@amd.com for more details +# +# run the script as "./run_full_performance_tests.sh < node name> +# input arguments: +# verification = 0 : do not verify result correctness on CPU +# = 1 : verifuy correctness on CPU (may take a long time) +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# node name : $hostname + +#get the command line arguments: +export verify=$1 +echo 'Verification: ' $verify +export env_type=$2 +echo 'Environment type: ' $env_type +export branch=$3 +echo 'Branch name: ' $branch +export host_name=$4 +echo 'Host name: ' $host_name +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run gemm tests +export gemm_log="perf_gemm.log" +print_log_header $gemm_log $env_type $branch $host_name +./profile_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 0 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 0 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 0 1 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 1 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 1 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 1 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 0 2 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 2 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 2 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 2 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 0 3 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 3 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 3 $verify 1 0 1 2>&1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 3 $verify 1 0 1 2>&1 | tee -a $gemm_log + +#run batched_gemm tests +export batched_gemm_log="perf_batched_gemm.log" +print_log_header $batched_gemm_log $env_type $branch $host_name +./profile_batched_gemm.sh batched_gemm 0 0 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 0 1 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 0 2 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 0 3 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 1 0 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 1 1 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 1 2 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 1 3 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 2 0 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 2 1 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 2 2 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 2 3 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 3 0 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 3 1 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 3 2 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log +./profile_batched_gemm.sh batched_gemm 3 3 $verify 1 0 1 2>&1 | tee -a $batched_gemm_log + +#run grouped_gemm tests +export grouped_gemm_log="perf_grouped_gemm.log" +print_log_header $grouped_gemm_log $env_type $branch $host_name +./profile_grouped_gemm.sh grouped_gemm 1 0 $verify 1 0 1 2>&1 | tee -a $grouped_gemm_log +./profile_grouped_gemm.sh grouped_gemm 1 1 $verify 1 0 1 2>&1 | tee -a $grouped_gemm_log +./profile_grouped_gemm.sh grouped_gemm 1 2 $verify 1 0 1 2>&1 | tee -a $grouped_gemm_log +./profile_grouped_gemm.sh grouped_gemm 1 3 $verify 1 0 1 2>&1 | tee -a $grouped_gemm_log + +#run GEMM+Bilinear tests +export gemm_bilinear_log="perf_gemm_bilinear.log" +print_log_header $gemm_bilinear_log $env_type $branch $host_name +./profile_gemm_bilinear.sh gemm_bilinear 1 0 $verify 1 0 1 2>&1 | tee -a $gemm_bilinear_log +./profile_gemm_bilinear.sh gemm_bilinear 1 1 $verify 1 0 1 2>&1 | tee -a $gemm_bilinear_log +./profile_gemm_bilinear.sh gemm_bilinear 1 2 $verify 1 0 1 2>&1 | tee -a $gemm_bilinear_log +./profile_gemm_bilinear.sh gemm_bilinear 1 3 $verify 1 0 1 2>&1 | tee -a $gemm_bilinear_log + +#run conv_fwd tests +export conv_fwd_log="perf_conv_fwd.log" +print_log_header $conv_fwd_log $env_type $branch $host_name +./profile_conv_fwd.sh conv_fwd 0 1 $verify 1 0 1 256 2>&1 | tee -a $conv_fwd_log +./profile_conv_fwd.sh conv_fwd 1 1 $verify 1 0 1 256 2>&1 | tee -a $conv_fwd_log +./profile_conv_fwd.sh conv_fwd 2 1 $verify 1 0 1 256 2>&1 | tee -a $conv_fwd_log +./profile_conv_fwd.sh conv_fwd 3 1 $verify 1 0 1 256 2>&1 | tee -a $conv_fwd_log + +#run conv_bwd_data tests +export conv_bwd_data_log="perf_conv_bwd_data.log" +print_log_header $conv_bwd_data_log $env_type $branch $host_name +./profile_conv_bwd_data.sh conv_bwd_data 0 1 $verify 1 0 1 256 2>&1 | tee -a $conv_bwd_data_log +./profile_conv_bwd_data.sh conv_bwd_data 1 1 $verify 1 0 1 256 2>&1 | tee -a $conv_bwd_data_log +./profile_conv_bwd_data.sh conv_bwd_data 2 1 $verify 1 0 1 256 2>&1 | tee -a $conv_bwd_data_log +./profile_conv_bwd_data.sh conv_bwd_data 3 1 $verify 1 0 1 256 2>&1 | tee -a $conv_bwd_data_log + +#run resnet50 tests +export resnet256_log="perf_resnet50_N256.log" +print_log_header $resnet256_log $env_type $branch $host_name +./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 256 2>&1 | tee -a $resnet256_log +export resnet4_log="perf_resnet50_N4.log" +print_log_header $resnet4_log $env_type $branch $host_name +./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 4 2>&1 | tee -a $resnet4_log + +#run reduction tests +export reduction_log="perf_reduction.log" +print_log_header $reduction_log $env_type $branch $host_name +./profile_reduce_with_index.sh $verify 2 10 --half 2>&1 | tee -a $reduction_log +./profile_reduce_no_index.sh $verify 2 10 --half 2>&1 | tee -a $reduction_log + +#run splitK_gemm tests, first correctness verification, then performance +export splitK_gemm_ver_log="perf_splitK_gemm_verify.log" +print_log_header $splitK_gemm_ver_log $env_type $branch $host_name +./profile_splitK_gemm.sh gemm_splitk 0 0 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +./profile_splitK_gemm.sh gemm_splitk 0 1 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +./profile_splitK_gemm.sh gemm_splitk 0 2 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +./profile_splitK_gemm.sh gemm_splitk 0 3 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +./profile_splitK_gemm.sh gemm_splitk 1 0 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +./profile_splitK_gemm.sh gemm_splitk 1 1 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +./profile_splitK_gemm.sh gemm_splitk 1 2 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +./profile_splitK_gemm.sh gemm_splitk 1 3 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log +export splitK_gemm_log="perf_splitK_gemm.log" +print_log_header $splitK_gemm_log $env_type $branch $host_name +./profile_splitK_gemm.sh gemm_splitk 0 0 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 1 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 2 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 3 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 0 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 1 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 2 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 3 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log + +#run ONNX gemm tests +export onnx_log="perf_onnx_gemm.log" +print_log_header $onnx_log $env_type $branch $host_name +./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log +./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log diff --git a/3rdparty/composable_kernel/script/run_performance_tests.sh b/3rdparty/composable_kernel/script/run_performance_tests.sh new file mode 100644 index 0000000000000000000000000000000000000000..4e3a6fc8eb675095fce13a9a97ce9df2d6d8630d --- /dev/null +++ b/3rdparty/composable_kernel/script/run_performance_tests.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ +# run the script as "./run_performance_tests.sh < node name> +# input arguments: +# verification = 0 : do not verify result correctness on CPU +# = 1 : verify correctness on CPU (may take a long time) +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# node name : $hostname + +#get the command line arguments: +export verify=$1 +echo 'Verification: ' $verify +export env_type=$2 +echo 'Environment type: ' $env_type +export branch=$3 +echo 'Branch name: ' $branch +export host_name=$4 +echo 'Host name: ' $host_name + +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run gemm tests +export gemm_log="perf_gemm.log" +print_log_header $gemm_log $env_type $branch $host_name +./profile_gemm.sh gemm 0 0 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 0 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 0 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 0 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 0 1 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 1 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 1 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 1 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 0 2 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 2 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 2 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 2 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 0 3 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 1 3 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 2 3 $verify 1 0 1 | tee -a $gemm_log +./profile_gemm.sh gemm 3 3 $verify 1 0 1 | tee -a $gemm_log + +#run resnet50 tests +export resnet256_log="perf_resnet50_N256.log" +print_log_header $resnet256_log $env_type $branch $host_name +./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 256 | tee -a $resnet256_log +export resnet4_log="perf_resnet50_N4.log" +print_log_header $resnet4_log $env_type $branch $host_name +./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1 $verify 1 0 1 4 | tee -a $resnet4_log diff --git a/3rdparty/composable_kernel/script/test_convnd_fwd.sh b/3rdparty/composable_kernel/script/test_convnd_fwd.sh new file mode 100644 index 0000000000000000000000000000000000000000..1bd7a6b5d71dfdc185b287a47fbda2dd1b7b4d86 --- /dev/null +++ b/3rdparty/composable_kernel/script/test_convnd_fwd.sh @@ -0,0 +1,110 @@ +#!/usr/bin/env bash + +# set -e + +DIM1=False +DIM2=True +DIM3=False +DATE=220317 +GIT_HASH=4e6dfda +LOG_DIR=${DATE}_${GIT_HASH} +SUFFIX=${GIT_HASH} + + +#-------------------------------------------------------------------------- +# Commandline arguments parsing +# like: cmd -key[--key] value +#-------------------------------------------------------------------------- + +POSITIONAL=() +while [[ $# -gt 0 ]] +do +key="$1" + +case $key in + -d1|--d1) + DIM1=True + echo DIM1: "${DIM1}" + shift # past argument + ;; + -d2|--d2) + DIM2=True + echo DIM2: "${DIM2}" + shift # past argument + ;; + -d3|--d3) + DIM3=True + echo DIM3: "${DIM3}" + shift # past argument + ;; + -all|--all) + DIM1=True + DIM2=True + DIM3=True + echo DIM1: "${DIM1}" + echo DIM2: "${DIM2}" + echo DIM3: "${DIM3}" + shift # past argument + ;; + -s|--suffix) + SUFFIX=${SUFFIX}_"$2" + echo SUFFIX: "${SUFFIX}" + shift # past argument + shift # past value + ;; + *) # unknown option + POSITIONAL+=("$1") # save it in an array for later + shift # past argument + ;; +esac +done +set -- "${POSITIONAL[@]}" # restore positional parameters + +#-------------------------------------------------------------------------- + +# NUMACTL="numactl --cpunodebind=1 --membind=1" +NUMACTL= +# ENV_CONF= +GPU=mi100 +PROF_ITER_COUNT=10000 +LOG_DIR_PATH=../log/${LOG_DIR} +set -x + +#------------------------------------------------------------------------------- +# 1D +#------------------------------------------------------------------------------- + +if [[ "${DIM1}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv1d nwc <<<<<<<<<<" + CMD="./../build/bin/test_conv1d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv1d_fwd_nwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 2D +#------------------------------------------------------------------------------- + +if [[ "${DIM2}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv2d nhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv2d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv2d_fwd_nhwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 3D +#------------------------------------------------------------------------------- + +if [[ "${DIM3}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv3d ndhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv3d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv3d_fwd_ndhwc_${SUFFIX}_${GPU}.log + +fi diff --git a/3rdparty/composable_kernel/script/test_reduce_no_index.sh b/3rdparty/composable_kernel/script/test_reduce_no_index.sh new file mode 100644 index 0000000000000000000000000000000000000000..b956303837089d16f055fd4cc68088923b7eed36 --- /dev/null +++ b/3rdparty/composable_kernel/script/test_reduce_no_index.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +## The following will be used for CI + +set -x + +## for float +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 0 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 0 2 + +## for float64 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 6 2 + +## for float16 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 1 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 1 2 + +## for int8_t +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 3 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 3 2 + +## for bfloat16 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 5 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 5 2 + +set +x + diff --git a/3rdparty/composable_kernel/script/test_reduce_with_index.sh b/3rdparty/composable_kernel/script/test_reduce_with_index.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0843ba6c1be5d2c80222abb7edc160511721099 --- /dev/null +++ b/3rdparty/composable_kernel/script/test_reduce_with_index.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +## The following will be used for CI + +set -x + +## for float +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 0 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 0 2 + +## for float64 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 6 2 + +## for float16 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 1 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 1 2 + +## for int8_t +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 3 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 3 2 + +## for bfloat16 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 5 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 5 2 + +set +x + diff --git a/3rdparty/composable_kernel/test/CMakeLists.txt b/3rdparty/composable_kernel/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b2e25e4ca76b7c6cb87f62e6dd6203ff4054045e --- /dev/null +++ b/3rdparty/composable_kernel/test/CMakeLists.txt @@ -0,0 +1,60 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/ + ${PROJECT_SOURCE_DIR}/profiler/include +) + +include(googletest) + +add_custom_target(tests) + +function(add_test_executable TEST_NAME) + message("adding test ${TEST_NAME}") + add_executable(${TEST_NAME} ${ARGN}) + add_test(NAME ${TEST_NAME} COMMAND $) + add_dependencies(tests ${TEST_NAME}) + add_dependencies(check ${TEST_NAME}) + rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) +endfunction(add_test_executable TEST_NAME) + +include(GoogleTest) + +function(add_gtest_executable TEST_NAME) + message("adding gtest ${TEST_NAME}") + add_executable(${TEST_NAME} ${ARGN}) + add_dependencies(tests ${TEST_NAME}) + add_dependencies(check ${TEST_NAME}) + + # suppress gtest warnings + target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(${TEST_NAME} PRIVATE gtest_main) + add_test(NAME ${TEST_NAME} COMMAND $ ) + rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) +endfunction(add_gtest_executable TEST_NAME) + +add_subdirectory(magic_number_division) +add_subdirectory(space_filling_curve) +add_subdirectory(conv_util) +add_subdirectory(reference_conv_fwd) +add_subdirectory(gemm) +add_subdirectory(gemm_split_k) +add_subdirectory(gemm_reduce) +add_subdirectory(batched_gemm) +add_subdirectory(batched_gemm_reduce) +add_subdirectory(batched_gemm_gemm) +add_subdirectory(batched_gemm_softmax_gemm) +add_subdirectory(batched_gemm_softmax_gemm_permute) +add_subdirectory(grouped_gemm) +add_subdirectory(reduce) +add_subdirectory(convnd_fwd) +add_subdirectory(convnd_bwd_data) +add_subdirectory(grouped_convnd_fwd) +add_subdirectory(grouped_convnd_bwd_weight) +add_subdirectory(block_to_ctile_map) +add_subdirectory(softmax) +add_subdirectory(normalization) +add_subdirectory(data_type) +add_subdirectory(elementwise_normalization) +add_subdirectory(batchnorm) +if(GPU_TARGETS MATCHES "gfx1100") + add_subdirectory(wmma_op) +endif() diff --git a/3rdparty/composable_kernel/test/batched_gemm/CMakeLists.txt b/3rdparty/composable_kernel/test/batched_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0574f98e872dcb436b14df3f536f06a787d28c8d --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm/CMakeLists.txt @@ -0,0 +1,15 @@ +add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) +target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) +target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) + +add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) +target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) +target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) + +add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) +target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) +target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) + +add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) +target_link_libraries(test_batched_gemm_int8 PRIVATE utility) +target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) diff --git a/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_bf16.cpp b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..78be540627850426a075e2bde1516410ef070756 --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_bf16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/profile_batched_gemm_impl.hpp" + +namespace { +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +int main() +{ + int M = 256; + int N = 256; + int K = 128; + int BatchCount = 3; + + bool pass = true; + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + + std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl; + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_fp16.cpp b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6cbbedf6774d2f90f85e5128c5bae8fbe912dc8e --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/profile_batched_gemm_impl.hpp" + +namespace { +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +int main() +{ + int M = 512; + int N = 256; + int K = 128; + int BatchCount = 3; + + bool pass = true; + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + + std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_fp32.cpp b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9e565e264b493d6f7f1b248fe6d307150a4c989 --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_fp32.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/profile_batched_gemm_impl.hpp" + +namespace { +using ADataType = float; +using BDataType = float; +using CDataType = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +int main() +{ + int M = 256; + int N = 256; + int K = 128; + int BatchCount = 3; + + bool pass = true; + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + + std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl; + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_int8.cpp b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4da941a5766bc6decc5adc938b4ce39da5d8a6db --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm/batched_gemm_int8.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/profile_batched_gemm_impl.hpp" + +namespace { +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +int main() +{ + int M = 256; + int N = 256; + int K = 128; + int BatchCount = 3; + + bool pass = true; + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + + pass = pass && + ck::profiler::profile_batched_gemm_impl( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + + std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl; + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/test/batched_gemm_gemm/CMakeLists.txt b/3rdparty/composable_kernel/test/batched_gemm_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..386809717f242c5a23dd3c49437b4ebf055243eb --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_gemm/CMakeLists.txt @@ -0,0 +1,5 @@ +add_custom_target(test_batched_gemm_gemm) + +add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) +target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) +add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) \ No newline at end of file diff --git a/3rdparty/composable_kernel/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp b/3rdparty/composable_kernel/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa113de219437e900b2de8739ddaffb2caf3cefa --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_gemm_util.hpp" + +template +class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, + std::tuple + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes); + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) { this->Run(); } + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM) +{ + this->lengths_ = std::vector>{ + {136, 128, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN) +{ + this->lengths_ = std::vector>{ + {128, 136, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK) +{ + this->lengths_ = std::vector>{ + {128, 128, 40, 128, 1}, + {128, 128, 136, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 136, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM) +{ + this->lengths_ = std::vector>{ + {129, 128, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN) +{ + this->lengths_ = std::vector>{ + {128, 129, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK) +{ + this->lengths_ = std::vector>{ + {128, 128, 33, 128, 1}, + {128, 128, 129, 128, 1}, + }; + this->Run(); +} + +// If kernel B1Layout is RowMajor, expect not to support odd O size +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 129, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16) +{ + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 768}, + {256, 256, 128, 128, 768}, + {512, 512, 64, 64, 768}, + {512, 512, 128, 128, 768}, + {1024, 1024, 64, 64, 768}, + {1024, 1024, 128, 128, 768}, + {2048, 2048, 64, 64, 768}, + {2048, 2048, 128, 128, 768}, + {4096, 4096, 64, 64, 768}, + {4096, 4096, 128, 128, 768}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +using ck::tensor_operation::device::GemmSpecialization; + +TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMatch) +{ + int P = 120; // requires padding + int Q = 128; // do not require padding + + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, P)); + // clang-format on +} + +TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch) +{ + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + // clang-format on +} diff --git a/3rdparty/composable_kernel/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp b/3rdparty/composable_kernel/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..53c4d37c44781be976f917362d0c433d1ec1711f --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "profiler/profile_batched_gemm_gemm_impl.hpp" + +using ck::tensor_operation::device::GemmSpecialization; + +template +using I = ck::Number; + +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +struct TestBatchedGemmGemm : public ::testing::Test +{ + using ADataType = std::tuple_element_t<0, Tuple>; + using B0DataType = std::tuple_element_t<1, Tuple>; + using B1DataType = std::tuple_element_t<2, Tuple>; + using CDataType = std::tuple_element_t<3, Tuple>; + using ALayout = std::tuple_element_t<4, Tuple>; + using B0Layout = std::tuple_element_t<5, Tuple>; + using B1Layout = std::tuple_element_t<6, Tuple>; + using CLayout = std::tuple_element_t<7, Tuple>; + + std::vector> lengths_ = { + {256, 256, 64, 64, 4}, + {256, 256, 128, 128, 4}, + {512, 512, 64, 64, 2}, + {512, 512, 128, 128, 2}, + {1024, 1024, 64, 64, 1}, + {1024, 1024, 128, 128, 1}, + }; + bool bench_ = false; + bool verify_ = true; + + void RunSingle(int M, int N, int K, int O, int BatchCount) + { + bool pass = ck::profiler::profile_batched_gemm_gemm_impl( + verify_, 1, false, bench_, M, N, K, O, BatchCount); + + EXPECT_TRUE(pass); + } + + void Run() + { + for(auto lengths : this->lengths_) + { + int M = lengths[0]; + int N = lengths[1]; + int K = lengths[2]; + int O = lengths[3]; + int BatchCount = lengths[4]; + + this->RunSingle(M, N, K, O, BatchCount); + } + } +}; + +template +struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128 +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using ALayout = Row; + using B0Layout = Col; + using B1Layout = Row; + using CLayout = Row; + + using ADataType = F16; + using B0DataType = F16; + using B1DataType = F16; + using AccDataType = float; + using CShuffleDataType = float; + using CDataType = F16; + + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = PassThrough; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + + template + using S = ck::Sequence; + + // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; + + using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + + bool IsSupported(int M, int N, int K, int O) + { + auto gemm = DeviceGemmGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + M, + N, + K, + O, + 0, // BatchCount + 0, // StrideA + 0, // StrideB0 + 0, // StrideB1 + 0, // StrideC + 0, // BatchStrideA + 0, // BatchStrideB0 + 0, // BatchStrideB1 + 0, // BatchStrideC + PassThrough{}, // a_element_op + PassThrough{}, // b0_element_op + PassThrough{}, // acc0_element_op + PassThrough{}, // b1_element_op + PassThrough{}); // c_element_op + + return gemm.IsSupportedArgument(argument); + } +}; diff --git a/3rdparty/composable_kernel/test/batched_gemm_reduce/CMakeLists.txt b/3rdparty/composable_kernel/test/batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4dc0b0825744294d219a39d46ec4f8aec06fdf93 --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) diff --git a/3rdparty/composable_kernel/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/3rdparty/composable_kernel/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b150ce50d166d5a673c6b34c468b7988f3b4d91f --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "profiler/profile_batched_gemm_reduce_impl.hpp" + +int main() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + int M = 512; + int N = 256; + int K = 128; + + int BatchCount = 3; + + bool pass = true; + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, K, K, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, false, M, N, K, M, K, N, BatchCount); + + if(pass) + { + std::cout << "test BatchedGEMM+Reduce fp16: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test BatchedGEMM+Reduce fp16: Fail" << std::endl; + return -1; + } +} diff --git a/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/CMakeLists.txt b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ceecefb5f2c0a86e6d951948c84a2286023773f --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1,5 @@ +add_custom_target(test_batched_gemm_softmax_gemm) + +add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) +target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) +add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) \ No newline at end of file diff --git a/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5df7769d5f675e906e350738751923d26b6af39d --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_softmax_gemm_util.hpp" + +template +class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm +{ +}; + +using Masked = std::true_type; +using NoMask = std::false_type; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, + std::tuple + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes); + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16) { this->Run(); } + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadM) +{ + this->lengths_ = std::vector>{ + {136, 128, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadN) +{ + this->lengths_ = std::vector>{ + {128, 136, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadK) +{ + this->lengths_ = std::vector>{ + {128, 128, 40, 128, 1}, + {128, 128, 136, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 136, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddM) +{ + this->lengths_ = std::vector>{ + {129, 128, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddN) +{ + this->lengths_ = std::vector>{ + {128, 129, 32, 128, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddK) +{ + this->lengths_ = std::vector>{ + {128, 128, 33, 128, 1}, + {128, 128, 129, 128, 1}, + }; + this->Run(); +} + +// If kernel B1Layout is RowMajor, expect not to support odd O size +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 129, 1}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16) +{ + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 768}, + {256, 256, 128, 128, 768}, + {512, 512, 64, 64, 768}, + {512, 512, 128, 128, 768}, + {1024, 1024, 64, 64, 768}, + {1024, 1024, 128, 128, 768}, + {2048, 2048, 64, 64, 768}, + {2048, 2048, 128, 128, 768}, + {4096, 4096, 64, 64, 768}, + {4096, 4096, 128, 128, 768}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK) +{ + this->lengths_ = std::vector>{{256, 256, 160, 160, 16}, + {256, 64, 160, 64, 16}, + {1024, 1024, 80, 80, 16}, + {1024, 64, 80, 64, 16}, + {4096, 4096, 40, 40, 16}, + {4096, 64, 40, 64, 16}}; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +using ck::tensor_operation::device::GemmSpecialization; + +TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch) +{ + int P = 120; // requires padding + int Q = 128; // do not require padding + + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, P)); + // clang-format on +} + +TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch) +{ + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + // clang-format on +} + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest) +{ + this->lengths_ = std::vector>{ + {49, 49, 64, 64, 24}, + {64, 49, 64, 64, 24}, + {1020, 1020, 64, 128, 24}, + {576, 576, 64, 64, 24}, + }; + this->Run(); +} diff --git a/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..98debe19c3c0b419f46dd08321f43f846536772a --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" +#include "profiler/profile_batched_gemm_softmax_gemm_impl.hpp" +using ck::tensor_operation::device::GemmSpecialization; + +template +using I = ck::Number; + +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +struct TestBatchedGemmSoftmaxGemm : public ::testing::Test +{ + using ADataType = std::tuple_element_t<0, Tuple>; + using B0DataType = std::tuple_element_t<1, Tuple>; + using B1DataType = std::tuple_element_t<2, Tuple>; + using CDataType = std::tuple_element_t<3, Tuple>; + using ALayout = std::tuple_element_t<4, Tuple>; + using B0Layout = std::tuple_element_t<5, Tuple>; + using B1Layout = std::tuple_element_t<6, Tuple>; + using CLayout = std::tuple_element_t<7, Tuple>; + using MaskingType = std::tuple_element_t<8, Tuple>; + + std::vector> lengths_ = {{256, 256, 64, 64, 4}, + {256, 256, 128, 128, 4}, + {512, 512, 64, 64, 2}, + {512, 512, 128, 128, 2}, + {1024, 1024, 64, 64, 1}, + {1024, 1024, 128, 128, 1}, + {256, 256, 160, 160, 4}, + {256, 64, 160, 64, 4}, + {1024, 1024, 80, 80, 2}, + {1024, 64, 80, 64, 2}, + {4096, 4096, 40, 40, 1}, + {4096, 64, 40, 64, 1}}; + + bool bench_ = false; + bool verify_ = true; + + void RunSingle(int M, int N, int K, int O, int BatchCount) + { + bool pass = ck::profiler::profile_batched_gemm_softmax_gemm_impl( + verify_, 1, false, bench_, M, N, K, O, BatchCount); + + EXPECT_TRUE(pass); + } + + void Run() + { + for(auto lengths : this->lengths_) + { + int M = lengths[0]; + int N = lengths[1]; + int K = lengths[2]; + int O = lengths[3]; + int BatchCount = lengths[4]; + + this->RunSingle(M, N, K, O, BatchCount); + } + } +}; + +template +struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128 +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using ALayout = Row; + using B0Layout = Col; + using B1Layout = Row; + using CLayout = Row; + + using ADataType = F16; + using B0DataType = F16; + using B1DataType = F16; + using AccDataType = float; + using CShuffleDataType = float; + using CDataType = F16; + + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = PassThrough; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + + template + using S = ck::Sequence; + + // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; + + using DeviceGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + false>; + + bool IsSupported(int M, int N, int K, int O) + { + auto gemm = DeviceGemmGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + M, + N, + K, + O, + 0, // BatchCount + 0, // StrideA + 0, // StrideB0 + 0, // StrideB1 + 0, // StrideC + 0, // BatchStrideA + 0, // BatchStrideB0 + 0, // BatchStrideB1 + 0, // BatchStrideC + PassThrough{}, // a_element_op + PassThrough{}, // b0_element_op + PassThrough{}, // acc0_element_op + PassThrough{}, // b1_element_op + PassThrough{}); // c_element_op + + return gemm.IsSupportedArgument(argument); + } +}; diff --git a/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f858d9f2087632ad3309ccc5cf7a3e851bca6c4b --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -0,0 +1,8 @@ +add_custom_target(test_batched_gemm_softmax_gemm_permute) + +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) +target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) +target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) +add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) +add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) \ No newline at end of file diff --git a/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4a775e6d928d646e99887eac3dd174fee8450f84 --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_softmax_gemm_permute_util.hpp" + +template +class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16 + : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute +{ +}; + +using I1_t = ck::Number<1>; +using I2_t = ck::Number<2>; + +using MaskDisabled_t = + ck::integral_constant; +using MaskOutUpperTriangle_t = + ck::integral_constant; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, ck::Tuple<>, MaskDisabled_t>, + std::tuple, ck::Tuple<>, MaskOutUpperTriangle_t> + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, KernelTypes); + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Test_BF16) { this->Run(); } + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadM) +{ + this->lengths_ = std::vector>{ + {136, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadN) +{ + this->lengths_ = std::vector>{ + {128, 136, 32, 128, 3, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadK) +{ + this->lengths_ = std::vector>{ + {128, 128, 40, 128, 2, 4}, + {128, 128, 136, 128, 4, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 136, 1, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddM) +{ + this->lengths_ = std::vector>{ + {129, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddN) +{ + this->lengths_ = std::vector>{ + {128, 129, 32, 128, 4, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddK) +{ + this->lengths_ = std::vector>{ + {128, 128, 33, 128, 2, 3}, + {128, 128, 129, 128, 2, 3}, + }; + this->Run(); +} + +// If kernel B1Layout is RowMajor, expect not to support odd O size +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 129, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16_IrregularK) +{ + this->lengths_ = std::vector>{{256, 256, 160, 160, 1, 16}, + {256, 64, 160, 64, 1, 16}, + {1024, 1024, 80, 80, 1, 16}, + {1024, 64, 80, 64, 1, 16}, + {4096, 4096, 40, 40, 1, 16}, + {4096, 64, 40, 64, 1, 16}}; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16) +{ + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 48, 16}, + {256, 256, 128, 128, 48, 16}, + {512, 512, 64, 64, 48, 16}, + {512, 512, 128, 128, 48, 16}, + {1024, 1024, 64, 64, 48, 16}, + {1024, 1024, 128, 128, 48, 16}, + {2048, 2048, 64, 64, 48, 16}, + {2048, 2048, 128, 128, 48, 16}, + {4096, 4096, 64, 64, 48, 16}, + {4096, 4096, 128, 128, 48, 16}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +using ck::tensor_operation::device::GemmSpecialization; + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch) +{ + int P = 120; // requires padding + int Q = 128; // do not require padding + + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, P, P)); + // clang-format on +} + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch) +{ + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + // clang-format on +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, AdhocTest) +{ + this->lengths_ = std::vector>{ + {49, 49, 64, 64, 4, 6}, + {64, 49, 64, 64, 4, 6}, + {1020, 1020, 64, 128, 4, 6}, + {576, 576, 64, 64, 4, 6}, + }; + this->Run(); +} diff --git a/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..293acd60155ace98eaea89a11eb03fa6710fd877 --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_softmax_gemm_permute_util.hpp" + +template +class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 + : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute +{ +}; + +using I1_t = ck::Number<1>; +using I2_t = ck::Number<2>; + +using MaskDisabled_t = + ck::integral_constant; +using MaskOutUpperTriangle_t = + ck::integral_constant; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, ck::Tuple<>, MaskDisabled_t>, + std::tuple, ck::Tuple<>, MaskOutUpperTriangle_t> + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, KernelTypes); + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16) { this->Run(); } + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadM) +{ + this->lengths_ = std::vector>{ + {136, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadN) +{ + this->lengths_ = std::vector>{ + {128, 136, 32, 128, 3, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadK) +{ + this->lengths_ = std::vector>{ + {128, 128, 40, 128, 2, 4}, + {128, 128, 136, 128, 4, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 136, 1, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddM) +{ + this->lengths_ = std::vector>{ + {129, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddN) +{ + this->lengths_ = std::vector>{ + {128, 129, 32, 128, 4, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddK) +{ + this->lengths_ = std::vector>{ + {128, 128, 33, 128, 2, 3}, + {128, 128, 129, 128, 2, 3}, + }; + this->Run(); +} + +// If kernel B1Layout is RowMajor, expect not to support odd O size +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 129, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK) +{ + this->lengths_ = std::vector>{{256, 256, 160, 160, 1, 16}, + {256, 64, 160, 64, 1, 16}, + {1024, 1024, 80, 80, 1, 16}, + {1024, 64, 80, 64, 1, 16}, + {4096, 4096, 40, 40, 1, 16}, + {4096, 64, 40, 64, 1, 16}}; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16) +{ + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 48, 16}, + {256, 256, 128, 128, 48, 16}, + {512, 512, 64, 64, 48, 16}, + {512, 512, 128, 128, 48, 16}, + {1024, 1024, 64, 64, 48, 16}, + {1024, 1024, 128, 128, 48, 16}, + {2048, 2048, 64, 64, 48, 16}, + {2048, 2048, 128, 128, 48, 16}, + {4096, 4096, 64, 64, 48, 16}, + {4096, 4096, 128, 128, 48, 16}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +using ck::tensor_operation::device::GemmSpecialization; + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch) +{ + int P = 120; // requires padding + int Q = 128; // do not require padding + + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, P)); + // clang-format on +} + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch) +{ + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + // clang-format on +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest) +{ + this->lengths_ = std::vector>{ + {49, 49, 64, 64, 4, 6}, + {64, 49, 64, 64, 4, 6}, + {1020, 1020, 64, 128, 4, 6}, + {576, 576, 64, 64, 4, 6}, + }; + this->Run(); +} diff --git a/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..912bbc91edd51d06c827fc31495d2955bea85a2b --- /dev/null +++ b/3rdparty/composable_kernel/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp" + +using ck::tensor_operation::device::GemmSpecialization; +using ck::tensor_operation::device::MaskingSpecialization; +using ck::tensor_operation::device::TensorSpecialization; + +template +using I = ck::Number; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test +{ + using NumDimGType = std::tuple_element_t<0, Tuple>; + using NumDimMType = std::tuple_element_t<1, Tuple>; + using NumDimNType = std::tuple_element_t<2, Tuple>; + using NumDimKType = std::tuple_element_t<3, Tuple>; + using NumDimOType = std::tuple_element_t<4, Tuple>; + using ADataType = std::tuple_element_t<5, Tuple>; + using B0DataType = std::tuple_element_t<6, Tuple>; + using B1DataType = std::tuple_element_t<7, Tuple>; + using CDataType = std::tuple_element_t<8, Tuple>; + using Acc0BiasDataType = std::tuple_element_t<9, Tuple>; + using Acc1BiasDataType = std::tuple_element_t<10, Tuple>; + using MaskingType = std::tuple_element_t<11, Tuple>; + + std::vector> lengths_ = { + {256, 256, 64, 64, 6, 4}, + {256, 256, 128, 128, 4, 6}, + {512, 512, 64, 64, 3, 2}, + {512, 512, 128, 128, 2, 3}, + {1024, 1024, 64, 64, 3, 1}, + {1024, 1024, 128, 128, 1, 1}, + }; + bool bench_ = false; + bool verify_ = true; + + void RunSingle(int M, int N, int K, int O, int G0, int G1) + { + bool pass = + ck::profiler::profile_batched_gemm_softmax_gemm_permute_impl, + ck::Tuple<>, + MaskingType::value>( + verify_, 2, false, bench_, M, N, K, O, G0, G1); + + EXPECT_TRUE(pass); + } + + void Run() + { + for(auto lengths : this->lengths_) + { + int M = lengths[0]; + int N = lengths[1]; + int K = lengths[2]; + int O = lengths[3]; + int G0 = lengths[4]; + int G1 = lengths[5]; + + this->RunSingle(M, N, K, O, G0, G1); + } + } +}; + +template +struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128 +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Scale = ck::tensor_operation::element_wise::Scale; + + template + using S = ck::Sequence; + + using ADataType = F16; + using B0DataType = F16; + using B1DataType = F16; + using AccDataType = float; + using CShuffleDataType = F16; + using CDataType = F16; + + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = Scale; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + + // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; + + using DeviceGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + 2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecialization::Default, // ATensorSpec + TensorSpecialization::Default, // B0TensorSpec + TensorSpecialization::Default, // B1TensorSpec + TensorSpecialization::Default, // CTensorSpec + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle + + bool IsSupported(int M, int N, int K, int O) + { + const int G0 = 1, G1 = 1; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + auto gemm = DeviceGemmGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + {}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // acc0_biases_gs_ms_ns_lengths + {}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + PassThrough{}, // a_element_op + PassThrough{}, // b0_element_op + Scale{1.f}, // acc0_element_op + PassThrough{}, // b1_element_op + PassThrough{}); // c_element_op + + return gemm.IsSupportedArgument(argument); + } +}; + +template +struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128 +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Scale = ck::tensor_operation::element_wise::Scale; + + template + using S = ck::Sequence; + + using ADataType = BF16; + using B0DataType = BF16; + using B1DataType = BF16; + using AccDataType = float; + using CShuffleDataType = BF16; + using CDataType = BF16; + + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = Scale; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + + // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; + + using DeviceGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + 2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecialization::Default, // ATensorSpec + TensorSpecialization::Default, // B0TensorSpec + TensorSpecialization::Default, // B1TensorSpec + TensorSpecialization::Default, // CTensorSpec + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle + + bool IsSupported(int M, int N, int K, int O) + { + const int G0 = 1, G1 = 1; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + auto gemm = DeviceGemmGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + {}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // acc0_biases_gs_ms_ns_lengths + {}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + PassThrough{}, // a_element_op + PassThrough{}, // b0_element_op + Scale{1.f}, // acc0_element_op + PassThrough{}, // b1_element_op + PassThrough{}); // c_element_op + + return gemm.IsSupportedArgument(argument); + } +}; diff --git a/3rdparty/composable_kernel/test/batchnorm/CMakeLists.txt b/3rdparty/composable_kernel/test/batchnorm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..52f15086822cfed89a89a3850e2ad966b2e2d762 --- /dev/null +++ b/3rdparty/composable_kernel/test/batchnorm/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_batchnorm_fwd_rank_4 batchnorm_fwd_rank_4.cpp) +add_gtest_executable(test_batchnorm_bwd_rank_4 batchnorm_bwd_rank_4.cpp) +target_link_libraries(test_batchnorm_fwd_rank_4 PRIVATE utility device_batchnorm_instance) +target_link_libraries(test_batchnorm_bwd_rank_4 PRIVATE utility device_batchnorm_instance) diff --git a/3rdparty/composable_kernel/test/batchnorm/batchnorm_bwd_rank_4.cpp b/3rdparty/composable_kernel/test/batchnorm/batchnorm_bwd_rank_4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..caa7331ea2c3e6dfe57a41272c5f4fba56662e39 --- /dev/null +++ b/3rdparty/composable_kernel/test/batchnorm/batchnorm_bwd_rank_4.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "profiler/profile_batchnorm_backward_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using F64 = double; + +template +class TestBatchNormBwdRank4 : public ::testing::Test +{ + private: + const double epsilon = std::numeric_limits::epsilon(); + + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using DxDataType = std::tuple_element_t<1, Tuple>; + using DyDataType = std::tuple_element_t<2, Tuple>; + using AccDataType = std::tuple_element_t<3, Tuple>; + using ScaleDataType = std::tuple_element_t<4, Tuple>; + using BiasDataType = std::tuple_element_t<5, Tuple>; + using MeanVarDataType = std::tuple_element_t<6, Tuple>; + + std::vector> list_of_lengths = { + {128, 16, 3, 1024}, {128, 16, 6, 512}, {1, 1, 1, 1}, {4, 4, 4, 4}, {32, 32, 32, 32}}; + std::vector reduceDims; + + template + void Run() + { + for(auto& inOutLengths : list_of_lengths) + { + bool pass = true; + + EXPECT_FALSE(reduceDims.size() != NumReduceDim); + + pass = pass && ck::profiler::profile_batchnorm_backward_impl( + true, 3, false, false, inOutLengths, reduceDims, true, epsilon); + + pass = pass && ck::profiler::profile_batchnorm_backward_impl( + true, 3, false, false, inOutLengths, reduceDims, false, epsilon); + + EXPECT_TRUE(pass); + } + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes); + +// nhwc +TYPED_TEST(TestBatchNormBwdRank4, nhwc) +{ + this->reduceDims = {0, 1, 2}; + this->template Run<3>(); +} + +// nchw +TYPED_TEST(TestBatchNormBwdRank4, nchw) +{ + this->reduceDims = {0, 2, 3}; + this->template Run<3>(); +} diff --git a/3rdparty/composable_kernel/test/batchnorm/batchnorm_fwd_rank_4.cpp b/3rdparty/composable_kernel/test/batchnorm/batchnorm_fwd_rank_4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13aef7d6bfc9f222e663b951a2f9d44e81fc93bc --- /dev/null +++ b/3rdparty/composable_kernel/test/batchnorm/batchnorm_fwd_rank_4.cpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "profiler/profile_batchnorm_forward_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F64 = double; + +template +class TestBatchNormFwdRank4 : public ::testing::Test +{ + private: + const double epsilon = std::numeric_limits::epsilon(); + const double averageFactor = 0.1; + + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using YDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using ScaleDataType = std::tuple_element_t<3, Tuple>; + using BiasDataType = std::tuple_element_t<4, Tuple>; + using MeanVarDataType = std::tuple_element_t<5, Tuple>; + + std::vector> list_of_lengths = { + {128, 16, 3, 1024}, {128, 16, 6, 512}, {1, 1, 1, 1}, {4, 4, 4, 4}, {32, 32, 32, 32}}; + std::vector reduceDims; + + template + void Run() + { + for(auto& inOutLengths : list_of_lengths) + { + bool pass = true; + + EXPECT_FALSE(reduceDims.size() != NumReduceDim); + + pass = + pass && ck::profiler::profile_batchnorm_forward_impl(true, + 3, + false, + false, + inOutLengths, + reduceDims, + true, + true, + epsilon, + averageFactor); + + pass = + pass && ck::profiler::profile_batchnorm_forward_impl(true, + 3, + false, + false, + inOutLengths, + reduceDims, + false, + false, + epsilon, + averageFactor); + + EXPECT_TRUE(pass); + } + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes); + +// nhwc +TYPED_TEST(TestBatchNormFwdRank4, nhwc) +{ + this->reduceDims = {0, 1, 2}; + this->template Run<3>(); +} + +// nchw +TYPED_TEST(TestBatchNormFwdRank4, nchw) +{ + this->reduceDims = {0, 2, 3}; + this->template Run<3>(); +} diff --git a/3rdparty/composable_kernel/test/block_to_ctile_map/CMakeLists.txt b/3rdparty/composable_kernel/test/block_to_ctile_map/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..97dfbb2b552ae739a62db52bf154ddb54226bdb2 --- /dev/null +++ b/3rdparty/composable_kernel/test/block_to_ctile_map/CMakeLists.txt @@ -0,0 +1 @@ +add_gtest_executable(test_block_to_ctile_map test_block_to_ctile_map.cpp) \ No newline at end of file diff --git a/3rdparty/composable_kernel/test/block_to_ctile_map/test_block_to_ctile_map.cpp b/3rdparty/composable_kernel/test/block_to_ctile_map/test_block_to_ctile_map.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55d9b59f489203cf68b5aebc1c5d9642801435c0 --- /dev/null +++ b/3rdparty/composable_kernel/test/block_to_ctile_map/test_block_to_ctile_map.cpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +using namespace ck; + +static auto I0 = Number<0>{}; +static auto I1 = Number<1>{}; +static auto I2 = Number<2>{}; + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 tile_map( + c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {0, 1, 1}, + {0, 2, 1}, + {0, 3, 0}, + {1, 0, 1}, + {1, 1, 1}, + {1, 2, 1}, + {1, 3, 0}, + {2, 0, 1}, + {2, 1, 1}, + {2, 2, 1}, + {2, 3, 0}, + {3, 0, 0}, + {3, 1, 0}, + {3, 2, 0}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 + tile_map(c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == false); +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 512; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 0}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 0}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 0}, + {0, 3, 1}, + {1, 3, 1}, + {2, 3, 1}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0) +{ + const index_t M = 512; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + // clang-format off + std::vector> expected_m0_gridsize_validity = { + {5, 15, false}, + {4, 12, true}, + {3, 18, false}, + {2, 12, true}, + {1, 12, true} + }; + // clang-format on + + for(auto e : expected_m0_gridsize_validity) + { + const index_t M01 = std::get<0>(e); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), std::get<1>(e)); + EXPECT_EQ(tile_map.CheckValidity(c_grid_desc_m_n), std::get<2>(e)); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01Adapt tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 1}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 1}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 1}, + {4, 0, 1}, + {5, 0, 1}, + {4, 1, 1}, + {5, 1, 1}, + {4, 2, 1}, + {5, 2, 1}, + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + const index_t KSplit = 3; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_KSplit_M00_N0_M01Adapt + tile_map(c_grid_desc_m_n, M01, KSplit); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18 * KSplit); + + std::vector> expected_ksplitidx_m0idx_n0idx_valid = { + {0, 0, 0, 1}, {0, 1, 0, 1}, {0, 2, 0, 1}, {0, 3, 0, 1}, {0, 0, 1, 1}, {0, 1, 1, 1}, + {0, 2, 1, 1}, {0, 3, 1, 1}, {0, 0, 2, 1}, {0, 1, 2, 1}, {0, 2, 2, 1}, {0, 3, 2, 1}, + {0, 4, 0, 1}, {0, 5, 0, 1}, {0, 4, 1, 1}, {0, 5, 1, 1}, {0, 4, 2, 1}, {0, 5, 2, 1}, + {1, 0, 0, 1}, {1, 1, 0, 1}, {1, 2, 0, 1}, {1, 3, 0, 1}, {1, 0, 1, 1}, {1, 1, 1, 1}, + {1, 2, 1, 1}, {1, 3, 1, 1}, {1, 0, 2, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}, {1, 3, 2, 1}, + {1, 4, 0, 1}, {1, 5, 0, 1}, {1, 4, 1, 1}, {1, 5, 1, 1}, {1, 4, 2, 1}, {1, 5, 2, 1}, + {2, 0, 0, 1}, {2, 1, 0, 1}, {2, 2, 0, 1}, {2, 3, 0, 1}, {2, 0, 1, 1}, {2, 1, 1, 1}, + {2, 2, 1, 1}, {2, 3, 1, 1}, {2, 0, 2, 1}, {2, 1, 2, 1}, {2, 2, 2, 1}, {2, 3, 2, 1}, + {2, 4, 0, 1}, {2, 5, 0, 1}, {2, 4, 1, 1}, {2, 5, 1, 1}, {2, 4, 2, 1}, {2, 5, 2, 1}, + }; + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto ksplitm0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", ksplit, m0, n0 = " << ksplitm0n0_idx[I0] << ", " + << ksplitm0n0_idx[I1] << ", " << ksplitm0n0_idx[I2]; + std::cout << ", valid = " + << tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_ksplitidx_m0idx_n0idx_valid[i] == + std::vector{ksplitm0n0_idx[I0], + ksplitm0n0_idx[I1], + ksplitm0n0_idx[I2], + tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} diff --git a/3rdparty/composable_kernel/test/conv_util/CMakeLists.txt b/3rdparty/composable_kernel/test/conv_util/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7a46039f15f12df1961b1378146ec380c7e8b8b8 --- /dev/null +++ b/3rdparty/composable_kernel/test/conv_util/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_conv_util conv_util.cpp) +target_link_libraries(test_conv_util PRIVATE utility) diff --git a/3rdparty/composable_kernel/test/conv_util/conv_util.cpp b/3rdparty/composable_kernel/test/conv_util/conv_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73797a7169e78cfeabac101ac0c7da366187edbf --- /dev/null +++ b/3rdparty/composable_kernel/test/conv_util/conv_util.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" + +namespace { + +class TestConvUtil : public ::testing::Test +{ + public: + void SetNDParams(std::size_t ndims, std::size_t s, std::size_t d, std::size_t p) + { + conv_params = ck::utils::conv::ConvParam(ndims, + 2, + 128, + 192, + 256, + std::vector(ndims, 3), + std::vector(ndims, 71), + std::vector(ndims, s), + std::vector(ndims, d), + std::vector(ndims, p), + std::vector(ndims, p)); + } + + protected: + // ------- default 2D ------- + // input GNCHW {2, 128, 192, 71, 71}, + // weights GKCYX {2, 256, 192, 3, 3}, + // stride {s, s}, + // dilations {d, d}, + // padding {{p, p}, {p, p} + ck::utils::conv::ConvParam conv_params; +}; + +} // namespace + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) +{ + // stride 2, dilation 1, pad 1 + SetNDParams(1, 2, 1, 1); + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); + + // stride 1, dilation 1, pad 1 + SetNDParams(1, 1, 1, 1); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); + + // stride 2, dilation 1, pad 2 + SetNDParams(1, 2, 1, 2); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}.")); + + // stride 2, dilation 2, pad 2 + SetNDParams(1, 2, 2, 2); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}.")); + + // stride 3, dilation 2, pad 1 + SetNDParams(1, 3, 2, 1); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) +{ + // stride 2, dilation 1, pad 1 + SetNDParams(2, 2, 1, 1); + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor.")); + + // stride 1, dilation 1, pad 1 + SetNDParams(2, 1, 1, 1); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}.")); + + // stride 2, dilation 1, pad 2 + SetNDParams(2, 2, 1, 2); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}.")); + + // stride 2, dilation 2, pad 2 + SetNDParams(2, 2, 2, 2); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); + + // stride 3, dilation 2, pad 1 + SetNDParams(2, 3, 2, 1); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE( + ck::utils::check_err(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.")); +} + +TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) +{ + // stride 2, dilation 1, pad 1 + SetNDParams(3, 2, 1, 1); + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); + + // stride 1, dilation 1, pad 1 + SetNDParams(3, 1, 1, 1); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{71, 71, 71}, + "Error: ConvParams 3D stride {1, 1, 1}.")); + + // stride 2, dilation 1, pad 2 + SetNDParams(3, 2, 1, 2); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{37, 37, 37}, + "Error: ConvParams 3D padding left/right {2, 2, 2}.")); + + // stride 2, dilation 2, pad 2 + SetNDParams(3, 2, 2, 2); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36, 36}, + "Error: ConvParams 3D dilation {2, 2, 2}.")); + + // stride 3, dilation 2, pad 1 + SetNDParams(3, 3, 2, 1); + out_spatial_len = conv_params.GetOutputSpatialLengths(); + EXPECT_TRUE(ck::utils::check_err( + out_spatial_len, + std::vector{23, 23, 23}, + "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.")); +} diff --git a/3rdparty/composable_kernel/test/convnd_bwd_data/CMakeLists.txt b/3rdparty/composable_kernel/test/convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..16ca4de8727878b25c75357e4ccff9bbc0517955 --- /dev/null +++ b/3rdparty/composable_kernel/test/convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) +target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) diff --git a/3rdparty/composable_kernel/test/convnd_bwd_data/convnd_bwd_data.cpp b/3rdparty/composable_kernel/test/convnd_bwd_data/convnd_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70231d42ae5f9eb9a368c643b2cf4041ae5fac4e --- /dev/null +++ b/3rdparty/composable_kernel/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "profiler/profile_conv_bwd_data_impl.hpp" + +template +class TestConvndBwdData : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + std::vector conv_params; + + template + void Run() + { + for(auto& param : conv_params) + { + bool pass; + EXPECT_FALSE(conv_params.empty()); + pass = ck::profiler::profile_conv_bwd_data_impl< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + DataType, + DataType, + DataType>(true, // do_verification + 1, // init_method integer value + false, // do_log + false, // time_kernel + param); + EXPECT_TRUE(pass); + } + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; +TYPED_TEST_SUITE(TestConvndBwdData, KernelTypes); + +// 1d +TYPED_TEST(TestConvndBwdData, Conv1dBwdData) +{ + this->conv_params.clear(); + this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + this->template Run<1>(); +} + +// 2d +TYPED_TEST(TestConvndBwdData, Conv2dBwdData) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} + +// 3d +TYPED_TEST(TestConvndBwdData, Conv3dBwdData) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->template Run<3>(); +} diff --git a/3rdparty/composable_kernel/test/convnd_fwd/CMakeLists.txt b/3rdparty/composable_kernel/test/convnd_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..97e170d85118b641a8338cde918bf60bda9a3803 --- /dev/null +++ b/3rdparty/composable_kernel/test/convnd_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) +target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) diff --git a/3rdparty/composable_kernel/test/convnd_fwd/convnd_fwd.cpp b/3rdparty/composable_kernel/test/convnd_fwd/convnd_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1921a9bfbe0f1221ddada95c75835ab9c942292 --- /dev/null +++ b/3rdparty/composable_kernel/test/convnd_fwd/convnd_fwd.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "profiler/profile_conv_fwd_impl.hpp" + +template +class TestConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + std::vector conv_params; + + template + void Run() + { + for(auto& param : conv_params) + { + bool pass; + EXPECT_FALSE(conv_params.empty()); + pass = ck::profiler::profile_conv_fwd_impl< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + DataType, + DataType, + DataType>(true, // do_verification + 1, // init_method integer value + false, // do_log + false, // time_kernel + param); + EXPECT_TRUE(pass); + } + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; +TYPED_TEST_SUITE(TestConvndFwd, KernelTypes); + +// 1d +TYPED_TEST(TestConvndFwd, Conv1dFwd) +{ + this->conv_params.clear(); + this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + this->template Run<1>(); +} + +// 2d +TYPED_TEST(TestConvndFwd, Conv2dFwd) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 1, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 1, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} +// 3d +TYPED_TEST(TestConvndFwd, Conv3dFwd) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 1, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->template Run<3>(); +} diff --git a/3rdparty/composable_kernel/test/data_type/CMakeLists.txt b/3rdparty/composable_kernel/test/data_type/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..088fbfec719f9e948686637fe9bec280f9926bd2 --- /dev/null +++ b/3rdparty/composable_kernel/test/data_type/CMakeLists.txt @@ -0,0 +1,4 @@ +if (USE_BITINT_EXTENSION_INT4) + add_gtest_executable(test_int4 int4.cpp) + target_link_libraries(test_int4 PRIVATE utility) +endif() diff --git a/3rdparty/composable_kernel/test/data_type/int4.cpp b/3rdparty/composable_kernel/test/data_type/int4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..252a450bf96896e677cbc5f402ccfe722192d9bc --- /dev/null +++ b/3rdparty/composable_kernel/test/data_type/int4.cpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include "gtest/gtest.h" +#include + +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/library/utility/device_memory.hpp" + +using ck::int4_t; + +TEST(Int4, BaseArithmetic) +{ + int4_t a{1}; + int4_t b{-2}; + EXPECT_EQ(a + a, int4_t{2}); + EXPECT_EQ(a - a, int4_t{0}); + EXPECT_EQ(a + b, int4_t{-1}); + EXPECT_EQ(a - b, int4_t{3}); + EXPECT_EQ(a * a, int4_t{1}); + EXPECT_EQ(a * b, int4_t{-2}); + EXPECT_EQ(b * b, int4_t{4}); + EXPECT_EQ(a / b, int4_t{0}); + a = int4_t{4}; + EXPECT_EQ(a / b, int4_t{-2}); + b = int4_t{2}; + EXPECT_EQ(a % b, int4_t{0}); +} + +TEST(Int4, NumericLimits) +{ + EXPECT_EQ(ck::NumericLimits::Min(), int4_t{-8}); + EXPECT_EQ(ck::NumericLimits::Max(), int4_t{7}); + EXPECT_EQ(ck::NumericLimits::Lowest(), int4_t{-8}); +} + +TEST(Int4, MathOpsV2) +{ + int4_t a{4}; + int4_t b{-5}; + + EXPECT_EQ(ck::math::abs(a), int4_t{4}); + EXPECT_EQ(ck::math::abs(b), int4_t{5}); + EXPECT_FALSE(ck::math::isnan(b)); +} + +namespace { + +__global__ void copy(const int4_t* src, std::int8_t* dst, ck::index_t N) +{ + ck::index_t tid = ck::get_thread_global_1d_id(); + + const int8_t* src_i8 = reinterpret_cast(src); + + if(tid < N) + { + for(ck::index_t i = tid; i < N; i += ck::get_grid_size()) + { + dst[i] = src_i8[i]; + } + } +} + +__global__ void copy_with_static_cast(const int4_t* src, std::int8_t* dst, ck::index_t N) +{ + ck::index_t tid = ck::get_thread_global_1d_id(); + + if(tid < N) + { + for(ck::index_t i = tid; i < N; i += ck::get_grid_size()) + { + dst[i] = static_cast(src[i]); + } + } +} + +} // anonymous namespace + +TEST(Int4, CopyAsI8PositiveValue) +{ + constexpr std::size_t SIZE = 100; + std::vector h_src_i4(SIZE, 7); + std::vector h_src_i8(SIZE, 7); + std::vector h_dst_i8(SIZE, 0); + + DeviceMem d_src_i4(h_src_i4.size() * sizeof(int4_t)); + DeviceMem d_dst_i8(h_dst_i8.size() * sizeof(std::int8_t)); + + d_src_i4.SetZero(); + d_dst_i8.SetZero(); + + d_src_i4.ToDevice(h_src_i4.data()); + + copy<<<1, 64>>>(reinterpret_cast(d_src_i4.GetDeviceBuffer()), + reinterpret_cast(d_dst_i8.GetDeviceBuffer()), + SIZE); + hip_check_error(hipDeviceSynchronize()); + d_dst_i8.FromDevice(h_dst_i8.data()); + + for(std::size_t i = 0; i < SIZE; ++i) + { + EXPECT_EQ(h_src_i8[i], h_dst_i8[i]); + } +} + +TEST(Int4, DISABLED_CopyAsI8NegativeValue) +{ + constexpr std::size_t SIZE = 32; + std::vector h_src_i4(SIZE, -8); + std::vector h_src_i8(SIZE, -8); + std::vector h_dst_i8(SIZE, 0); + + DeviceMem d_src_i4(h_src_i4.size() * sizeof(int4_t)); + DeviceMem d_dst_i8(h_dst_i8.size() * sizeof(std::int8_t)); + + d_src_i4.SetZero(); + d_dst_i8.SetZero(); + + d_src_i4.ToDevice(h_src_i4.data()); + + copy<<<1, 64>>>(reinterpret_cast(d_src_i4.GetDeviceBuffer()), + reinterpret_cast(d_dst_i8.GetDeviceBuffer()), + SIZE); + hip_check_error(hipDeviceSynchronize()); + d_dst_i8.FromDevice(h_dst_i8.data()); + + for(std::size_t i = 0; i < SIZE; ++i) + { + EXPECT_EQ(h_src_i8[i], h_dst_i8[i]); + } +} + +TEST(Int4, CopyAsI8NegativeValueStaticCast) +{ + constexpr std::size_t SIZE = 32; + std::vector h_src_i4(SIZE, -8); + std::vector h_src_i8(SIZE, -8); + std::vector h_dst_i8(SIZE, 0); + + DeviceMem d_src_i4(h_src_i4.size() * sizeof(int4_t)); + DeviceMem d_dst_i8(h_dst_i8.size() * sizeof(std::int8_t)); + + d_src_i4.SetZero(); + d_dst_i8.SetZero(); + + d_src_i4.ToDevice(h_src_i4.data()); + + copy_with_static_cast<<<1, 64>>>(reinterpret_cast(d_src_i4.GetDeviceBuffer()), + reinterpret_cast(d_dst_i8.GetDeviceBuffer()), + SIZE); + hip_check_error(hipDeviceSynchronize()); + d_dst_i8.FromDevice(h_dst_i8.data()); + + for(std::size_t i = 0; i < SIZE; ++i) + { + EXPECT_EQ(h_src_i8[i], h_dst_i8[i]); + } +} + +TEST(Int4, DISABLED_BitwiseRepresentation) +{ + using bit8_t = std::bitset<8>; + + int4_t a_i4{3}; + std::int8_t a_i8 = *reinterpret_cast(&a_i4); + std::int8_t b_i8{3}; +#if 0 + std::cout << std::hex << std::showbase << static_cast(a_i8) + << ", " << static_cast(b_i8) << std::endl; +#endif + EXPECT_EQ(bit8_t{static_cast(a_i8)}, bit8_t{static_cast(b_i8)}); + + a_i4 = int4_t{-3}; + a_i8 = *reinterpret_cast(&a_i4); + b_i8 = std::int8_t{-3}; +#if 0 + std::cout << std::hex << std::showbase << static_cast(a_i8) + << ", " << static_cast(b_i8) << std::endl; +#endif + EXPECT_EQ(bit8_t{static_cast(a_i8)}, bit8_t{static_cast(b_i8)}); +} + +TEST(Int4, BitwiseRepresentationStaticCast) +{ + using bit8_t = std::bitset<8>; + + int4_t a_i4{3}; + std::int8_t a_i8 = static_cast(a_i4); + std::int8_t b_i8{3}; +#if 0 + std::cout << std::hex << std::showbase << static_cast(a_i8) + << ", " << static_cast(b_i8) << std::endl; +#endif + EXPECT_EQ(bit8_t{static_cast(a_i8)}, bit8_t{static_cast(b_i8)}); + + a_i4 = int4_t{-3}; + a_i8 = static_cast(a_i4); + b_i8 = std::int8_t{-3}; +#if 0 + std::cout << std::hex << std::showbase << static_cast(a_i8) + << ", " << static_cast(b_i8) << std::endl; +#endif + EXPECT_EQ(bit8_t{static_cast(a_i8)}, bit8_t{static_cast(b_i8)}); +} diff --git a/3rdparty/composable_kernel/test/elementwise_normalization/CMakeLists.txt b/3rdparty/composable_kernel/test/elementwise_normalization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a20eb263256af119453c81f399ea3605bd8b71cc --- /dev/null +++ b/3rdparty/composable_kernel/test/elementwise_normalization/CMakeLists.txt @@ -0,0 +1,7 @@ +add_custom_target(test_elementwise_normalization) + +add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp) + +target_link_libraries(test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance) + +add_dependencies(test_elementwise_normalization test_elementwise_layernorm_fp16) diff --git a/3rdparty/composable_kernel/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp b/3rdparty/composable_kernel/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..403881b3cc442d0c16779b99ac678b7e30111cd4 --- /dev/null +++ b/3rdparty/composable_kernel/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_elementwise_layernorm_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestElementwiseLayernorm : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using GammaDataType = std::tuple_element_t<2, Tuple>; + using BetaDataType = std::tuple_element_t<3, Tuple>; + using AccDataType = std::tuple_element_t<4, Tuple>; + using YDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // M, N + std::vector> lengths = { + {1, 1}, {25, 16}, {39, 777}, {100, 200}, {1024, 1024}, {48 * 256, 2048}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_elementwise_layernorm_impl( + true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // ADataType, BDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestElementwiseLayernorm, KernelTypes); +TYPED_TEST(TestElementwiseLayernorm, Test_FP16) { this->Run(); } diff --git a/3rdparty/composable_kernel/test/gemm/CMakeLists.txt b/3rdparty/composable_kernel/test/gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c427586bb7933c2affed0c54b043f4e985346928 --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/CMakeLists.txt @@ -0,0 +1,25 @@ +add_test_executable(test_gemm_fp32 gemm_fp32.cpp) +target_link_libraries(test_gemm_fp32 PRIVATE utility) +target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_fp16 gemm_fp16.cpp) +target_link_libraries(test_gemm_fp16 PRIVATE utility) +target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_bf16 gemm_bf16.cpp) +target_link_libraries(test_gemm_bf16 PRIVATE utility) +target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_int8 gemm_int8.cpp) +target_link_libraries(test_gemm_int8 PRIVATE utility) +target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) + +add_library(gemm_standalone_xdl_fp16_instances STATIC + instance/gemm_f16_nn_instance.cpp + instance/gemm_f16_nt_instance.cpp + instance/gemm_f16_tn_instance.cpp + instance/gemm_f16_tt_instance.cpp +) +add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp) +target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility) +target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/) diff --git a/3rdparty/composable_kernel/test/gemm/gemm_bf16.cpp b/3rdparty/composable_kernel/test/gemm/gemm_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5290d466323277daa4e61af1a8273b816a790120 --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/gemm_bf16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; + +#include "run_gemm_test.inc" + +int main() { return run_gemm_test(); } diff --git a/3rdparty/composable_kernel/test/gemm/gemm_fp16.cpp b/3rdparty/composable_kernel/test/gemm/gemm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..92e225def29623c7dcfc8590a5f805052dedb54d --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/gemm_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +#include "run_gemm_test.inc" + +int main() { return run_gemm_test(); } diff --git a/3rdparty/composable_kernel/test/gemm/gemm_fp32.cpp b/3rdparty/composable_kernel/test/gemm/gemm_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d8c4881b621504fe621d056cbcf68bcb04622fc --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/gemm_fp32.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +#include "run_gemm_test.inc" + +int main() { return run_gemm_test(); } diff --git a/3rdparty/composable_kernel/test/gemm/gemm_fp64.cpp b/3rdparty/composable_kernel/test/gemm/gemm_fp64.cpp new file mode 100644 index 0000000000000000000000000000000000000000..85d7f95bf4ad9ea8df19fdf2600c191f7eabcbf2 --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/gemm_fp64.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +using ADataType = double; +using BDataType = double; +using CDataType = double; +using AccDataType = double; + +#include "run_gemm_test.inc" + +int main() { return run_gemm_test(); } diff --git a/3rdparty/composable_kernel/test/gemm/gemm_int8.cpp b/3rdparty/composable_kernel/test/gemm/gemm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e73b22ce9c847f704fe55ad31cc83661b1ebab87 --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/gemm_int8.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "test/gemm/gemm_util.hpp" + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +#include "run_gemm_test.inc" + +int main() { return run_gemm_test(); } diff --git a/3rdparty/composable_kernel/test/gemm/gemm_standalone_xdl_fp16.cpp b/3rdparty/composable_kernel/test/gemm/gemm_standalone_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f5a5c557cdc4694895b67f18d6df293dece771e --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/gemm_standalone_xdl_fp16.cpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_util.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +#include "gemm_f16_nn_instance.hpp" +#include "gemm_f16_nt_instance.hpp" +#include "gemm_f16_tn_instance.hpp" +#include "gemm_f16_tt_instance.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using F16 = ck::half_t; +using ADataType = F16; +using BDataType = F16; +using AccDataType = float; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ck::gemm_util::GemmParams; +using ck::tensor_operation::device::BaseOperator; +using ck::tensor_operation::device::DeviceGemm; +using namespace ck::tensor_operation::device::instance; + +using DeviceGemmNN = + DeviceGemm; +using DeviceGemmNT = + DeviceGemm; +using DeviceGemmTN = + DeviceGemm; +using DeviceGemmTT = + DeviceGemm; + +struct LayoutConfig +{ + bool ARowMajor; + bool BRowMajor; + bool CRowMajor; +}; + +int main(int argc, char* argv[]) +{ + // Class DeviceGemm is templated by layout and precision types so it is not an option to contain + // them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it + // upon invocation. + // And since DeviceGemm does not expose template arg information, an extra book keeping class + // LayoutConfig is used for determining which type a BaseOperator instance should be cast to. + using OpFactoryFn = void (*)(std::vector>&); + + std::vector> problems = { + // clang-format off + // 104 tiles + {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256}, + {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128}, + {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128}, + {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64}, + {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256}, + {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128}, + {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128}, + {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64}, + {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256}, + {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128}, + {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128}, + {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64}, + {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256}, + {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128}, + {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128}, + {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64}, + // 110 tiles + {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256}, + {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128}, + {GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128}, + {GemmParams{1280, 704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64}, + {GemmParams{2560, 2816, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256}, + {GemmParams{2560, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128}, + {GemmParams{1280, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128}, + {GemmParams{1280, 704, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64}, + {GemmParams{2560, 2816, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256}, + {GemmParams{2560, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128}, + {GemmParams{1280, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128}, + {GemmParams{1280, 704, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64}, + {GemmParams{2560, 2816, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256}, + {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128}, + {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128}, + {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64}, + // clang-format on + }; + + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: time kernel (0=no, 1=yes)" << std::endl; + return 0; + } + + bool pass = true; + for(auto& p : problems) + { + GemmParams& problem_size = std::get<0>(p); + const LayoutConfig& layout_config = std::get<1>(p); + const auto& factory = std::get<2>(p); + std::vector> ops; + factory(ops); + + // overwrite strides + problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M; + problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K; + problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M; + + if(!layout_config.ARowMajor && !layout_config.BRowMajor) + { + auto op_ptr = dynamic_cast(ops[0].get()); + pass &= ck::gemm_util::TestGemm{}( + op_ptr, problem_size, do_verification, time_kernel); + } + else if(!layout_config.ARowMajor && layout_config.BRowMajor) + { + auto op_ptr = dynamic_cast(ops[0].get()); + pass &= ck::gemm_util::TestGemm{}( + op_ptr, problem_size, do_verification, time_kernel); + } + else if(layout_config.ARowMajor && !layout_config.BRowMajor) + { + auto op_ptr = dynamic_cast(ops[0].get()); + pass &= ck::gemm_util::TestGemm{}( + op_ptr, problem_size, do_verification, time_kernel); + } + else if(layout_config.ARowMajor && layout_config.BRowMajor) + { + auto op_ptr = dynamic_cast(ops[0].get()); + pass &= ck::gemm_util::TestGemm{}( + op_ptr, problem_size, do_verification, time_kernel); + } + } + + std::cout << (pass ? "ALL TESTS PASSED" : "SOME TESTS FAILED") << std::endl; + return pass ? 0 : 1; +} diff --git a/3rdparty/composable_kernel/test/gemm/gemm_util.hpp b/3rdparty/composable_kernel/test/gemm/gemm_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9057c0af896891107683605575ee17aeb460c6cf --- /dev/null +++ b/3rdparty/composable_kernel/test/gemm/gemm_util.hpp @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace gemm_util { + +struct GemmParams +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; +}; + +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, + const ck::gemm_util::GemmParams& params, + const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + bool time_kernel) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + auto invoker_ptr = gemmPtr->MakeInvokerPointer(); + auto argument_ptr = + gemmPtr->MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + params.M, + params.N, + params.K, + params.StrideA, + params.StrideB, + params.StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(gemmPtr->IsSupportedArgument(argument_ptr.get())) + { + a_m_k_device_buf.ToDevice(A.mData.data()); + b_k_n_device_buf.ToDevice(B.mData.data()); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * params.M * params.N * params.K; + std::size_t num_btype = sizeof(ADataType) * params.M * params.K + + sizeof(BDataType) * params.K * params.N + + sizeof(CDataType) * params.M * params.N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << std::endl; + + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; + } + else + { + std::cout << "device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return false; + } +} + +template +struct TestGemm +{ + template + auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_k_n, BDataType{}); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); + } + + template